diff --git a/.github/packaging/vllm_reqs.txt b/.github/packaging/vllm_reqs.txt index aad2b28bd..c7d38ec64 100644 --- a/.github/packaging/vllm_reqs.txt +++ b/.github/packaging/vllm_reqs.txt @@ -7,20 +7,20 @@ # See the file .github/workflows/gpu_test.yaml for an E2E forge installation using this approach. # TODO: this should be done way less hackily aiohappyeyeballs==2.6.1 -aiohttp==3.13.0 +aiohttp==3.13.1 aiosignal==1.4.0 annotated-types==0.7.0 anyio==4.11.0 astor==0.8.1 async-timeout==5.0.1 attrs==25.4.0 -blake3==1.0.7 -cachetools==6.2.0 +blake3==1.0.8 +cachetools==6.2.1 cbor2==5.7.0 certifi==2025.10.5 cffi==2.0.0 -charset-normalizer==3.4.3 -click==8.3.0 +charset-normalizer==3.4.4 +click==8.2.1 cloudpickle==3.1.1 cmake==4.1.0 compressed-tensors==0.10.2 @@ -33,7 +33,7 @@ dnspython==2.8.0 einops==0.8.1 email-validator==2.3.0 exceptiongroup==1.3.0 -fastapi==0.118.3 +fastapi==0.119.0 fastapi-cli==0.0.13 fastapi-cloud-cli==0.3.1 fastrlock==0.8.3 @@ -47,10 +47,10 @@ httpcore==1.0.9 httptools==0.7.1 httpx==0.28.1 huggingface-hub==0.35.3 -idna==3.10 +idna==3.11 interegular==0.3.3 Jinja2==3.1.6 -jiter==0.11.0 +jiter==0.11.1 jsonschema==4.25.1 jsonschema-specifications==2025.9.1 lark==1.2.2 @@ -58,70 +58,69 @@ llguidance==0.7.30 llvmlite==0.44.0 lm-format-enforcer==0.10.12 markdown-it-py==4.0.0 -MarkupSafe==3.0.2 +MarkupSafe==2.1.5 mdurl==0.1.2 mistral_common==1.8.5 mpmath==1.3.0 msgpack==1.1.2 msgspec==0.19.0 multidict==6.7.0 -networkx==3.4.2 +networkx==3.3 ninja==1.13.0 numba==0.61.2 numpy==2.2.6 -nvidia-cublas-cu12==12.9.1.4 -nvidia-cuda-cupti-cu12==12.9.79 -nvidia-cuda-nvrtc-cu12==12.9.86 -nvidia-cuda-runtime-cu12==12.9.79 +nvidia-cublas-cu12==12.8.4.1 +nvidia-cuda-cupti-cu12==12.8.90 +nvidia-cuda-nvrtc-cu12==12.8.93 +nvidia-cuda-runtime-cu12==12.8.90 nvidia-cudnn-cu12==9.10.2.21 -nvidia-cufft-cu12==11.4.1.4 -nvidia-cufile-cu12==1.14.1.1 -nvidia-curand-cu12==10.3.10.19 -nvidia-cusolver-cu12==11.7.5.82 -nvidia-cusparse-cu12==12.5.10.65 +nvidia-cufft-cu12==11.3.3.83 +nvidia-cufile-cu12==1.13.1.3 +nvidia-curand-cu12==10.3.9.90 +nvidia-cusolver-cu12==11.7.3.90 +nvidia-cusparse-cu12==12.5.8.93 nvidia-cusparselt-cu12==0.7.1 nvidia-nccl-cu12==2.27.5 -nvidia-nvjitlink-cu12==12.9.86 +nvidia-nvjitlink-cu12==12.8.93 nvidia-nvshmem-cu12==3.3.20 -nvidia-nvtx-cu12==12.9.79 +nvidia-nvtx-cu12==12.8.90 openai==1.90.0 opencv-python-headless==4.12.0.88 outlines_core==0.2.10 packaging==25.0 partial-json-parser==0.2.1.1.post6 -pillow==11.3.0 +pillow==12.0.0 prometheus-fastapi-instrumentator==7.1.0 prometheus_client==0.23.1 propcache==0.4.1 -protobuf==6.32.1 +protobuf==6.33.0 psutil==7.1.0 py-cpuinfo==9.0.0 pybase64==1.4.2 pycountry==24.6.1 pycparser==2.23 -pydantic==2.12.0 +pydantic==2.12.3 pydantic-extra-types==2.10.6 -pydantic_core==2.41.1 +pydantic_core==2.41.4 Pygments==2.19.2 python-dotenv==1.1.1 python-json-logger==4.0.0 python-multipart==0.0.20 -pytorch-triton==3.4.0+gitf7888497 PyYAML==6.0.3 pyzmq==27.1.0 -ray==2.49.2 -referencing==0.36.2 +ray==2.50.0 +referencing==0.37.0 regex==2025.9.18 requests==2.32.5 rich==14.2.0 rich-toolkit==0.15.1 -rignore==0.7.0 +rignore==0.7.1 rpds-py==0.27.1 safetensors==0.6.2 scipy==1.15.3 sentencepiece==0.2.1 -sentry-sdk==2.41.0 -setuptools-scm==9.2.0 +sentry-sdk==2.42.0 +setuptools-scm==9.2.1 shellingham==1.5.4 sniffio==1.3.1 soundfile==0.13.1 @@ -131,17 +130,17 @@ sympy==1.14.0 tiktoken==0.12.0 tokenizers==0.22.1 tomli==2.3.0 -torch==2.9.0.dev20250905+cu129 +torch==2.9.0+cu128 tqdm==4.67.1 -transformers==4.57.0 -triton==3.4.0 +transformers==4.57.1 +triton==3.5.0 typer==0.19.2 typing-inspection==0.4.2 typing_extensions==4.15.0 urllib3==2.5.0 uvicorn==0.37.0 -uvloop==0.21.0 -watchfiles==1.1.0 +uvloop==0.22.1 +watchfiles==1.1.1 websockets==15.0.1 xgrammar==0.1.21 yarl==1.22.0 diff --git a/.github/workflows/build_vllm.yaml b/.github/workflows/build_vllm.yaml index 0e8279ac4..4a90ecd88 100644 --- a/.github/workflows/build_vllm.yaml +++ b/.github/workflows/build_vllm.yaml @@ -12,15 +12,15 @@ permissions: jobs: build: - name: forge-cu129-nightly - uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main + name: forge-cu128-nightly + uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@vllm-push strategy: fail-fast: false with: repository: meta-pytorch/forge ref: "" test-infra-repository: pytorch/test-infra - test-infra-ref: main + test-infra-ref: vllm-push run-smoke-test: false wheel-nightly-policy: gha_workflow_preview_build_wheels wheel-upload-path: whl/preview/forge/ @@ -31,13 +31,13 @@ jobs: { "python_version": "3.10", "gpu_arch_type": "cpu", - "gpu_arch_version": "12.9", - "desired_cuda": "cu129", - "container_image": "pytorch/manylinux2_28-builder:cuda12.9", + "gpu_arch_version": "12.8", + "desired_cuda": "cu128", + "container_image": "pytorch/manylinux2_28-builder:cuda12.8", "package_type": "manywheel", - "build_name": "manywheel-py3_10-cuda12_9", + "build_name": "manywheel-py3_10-cuda12_8", "validation_runner": "linux.12xlarge.memory", - "installation": "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129", + "installation": "pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128", "channel": "nightly", "upload_to_base_bucket": "no", "stable_version": "2.8.0", diff --git a/.meta/mast/README.md b/.meta/mast/README.md index e6f64d739..235f48fcf 100644 --- a/.meta/mast/README.md +++ b/.meta/mast/README.md @@ -21,7 +21,7 @@ The `env_setup.sh` script will automatically: chmod +x .meta/mast/env_setup.sh # Run the setup -./.meta/mast/env_setup.sh +source .meta/mast/env_setup.sh ``` @@ -44,3 +44,82 @@ The launch script will automatically: - Launch the MAST job with the specified config You can run it from anywhere, and it will figure out the correct paths. + + +## How MAST Launcher Works + +The MAST launcher uses a two-stage architecture to run training jobs: + +### Stage 1: Detached Mode (Local Machine) + +When you run `./.meta/mast/launch.sh`, the `main.py` script starts in **detached mode**: + +1. The launcher creates a MAST job with all the worker roles (GPU hosts) +2. It also creates a special **client role** - a CPU-only role that will run inside MAST +3. The client role's entrypoint is set to `client_bootstrap.sh` +4. All CLI arguments you pass are forwarded to the client role + +At this point, the job is submitted to MAST and your local script exits. Everything now runs in the cluster. + +### Stage 2: Remote Mode (Inside MAST) + +The `client_bootstrap.sh` script runs inside the MAST client role and: + +1. Calls `main.py` again, but now with `--mode=remote` +2. In **remote mode**, the script: + - Mounts the OilFS workspace + - Initializes the provisioner to connect to worker roles + - Runs the actual training workload (e.g., GRPO) + +This architecture allows the entire training workflow to run inside MAST without requiring a persistent connection from your local machine. + +### Key Files + +- **`main.py`**: Entry point that handles both detached and remote modes +- **`client_bootstrap.sh`**: Entrypoint for the client role in MAST +- **`launcher.py`**: Creates the MAST job specification and handles role configuration + + +## Managing HuggingFace Models in MAST + +### The Problem: No Internet Access + +MAST compute nodes cannot access the internet, which means they cannot download models directly from HuggingFace. To work around this, we store all HuggingFace models and cache data on OilFS at `/mnt/wsfuse/teamforge/hf`, which is accessible from MAST. + +### Solution: Two-Step Process + +You need to perform both steps below to ensure models work correctly in MAST: + +#### 1. Download Model Weights to OilFS + +First, download the model weights directly to the OilFS path. This should be done from a machine with internet access (like your devserver): + +```bash +# Set HF_HOME to the OilFS path +export HF_HOME=/mnt/wsfuse/teamforge/hf + +# Download the model (replace with your desired model) +huggingface-cli download Qwen/Qwen3-8B --local-dir /mnt/wsfuse/teamforge/hf_artifacts/qwen3_8b +``` + +#### 2. Hydrate the HuggingFace Cache + +After downloading the weights, you need to hydrate the HuggingFace cache so that the transformers library can find the model metadata: + +```bash +# Set HF_HOME to the OilFS path +export HF_HOME=/mnt/wsfuse/teamforge/hf + +# Hydrate the cache for the model +python .meta/mast/hydrate_cache.py --model-id Qwen/Qwen3-8B +``` + +This ensures that when MAST runs with `HF_HUB_OFFLINE=1`, the transformers library can locate all necessary files from the cache. + +### Directory Structure + +Both cache and model files are stored under: +- **Cache**: `/mnt/wsfuse/teamforge/hf` (set via `HF_HOME`) +- **Model weights**: `/mnt/wsfuse/teamforge/hf/` + +Make sure your MAST config files point to the correct paths in `hf_artifacts`. diff --git a/.meta/mast/client_bootstrap.sh b/.meta/mast/client_bootstrap.sh new file mode 100755 index 000000000..1e73a7e3d --- /dev/null +++ b/.meta/mast/client_bootstrap.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Bootstrap script for the MAST client role +# This script sets up the environment and launches the client training script + +set -eEx + +LIBCUDA="/usr/local/fbcode/platform010/lib/libcuda.so" +if [ -f "$LIBCUDA" ]; then + export LIBCUDA_DIR="${LIBCUDA%/*}" + export TRITON_LIBCUDA_PATH="$LIBCUDA_DIR" + export LD_PRELOAD="$LIBCUDA:/usr/local/fbcode/platform010/lib/libnvidia-ml.so${PRELOAD_PATH:+:$PRELOAD_PATH}" +fi + +# Also preload put path to torch libs as for monarch dev workflow we dont +# install it into the env so we need to make sure the binaries can find +# libtorch and friends on mast and the rpaths set during dev install will +# be wrong on mast. +export LD_LIBRARY_PATH="${CONDA_DIR}/lib:${CONDA_DIR}/lib/python3.10/site-packages/torch/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}" +export PYTHONPATH="${PYTHONPATH:+$PYTHONPATH:}$TORCHX_RUN_PYTHONPATH" + +# shellcheck disable=SC1091 +if [ -n "$CONDA_PREFIX" ]; then + echo "A conda environment is already activated: $CONDA_DEFAULT_ENV" +else + # Disable command printing to avoid log spew. + set +x + source "${CONDA_DIR}/bin/activate" + # Re-enable command printing after conda activation. + set -x +fi + +if [ -z "$WORKSPACE_DIR" ] || [ ! -d "$WORKSPACE_DIR" ]; then + WORKSPACE_DIR="$CONDA_PREFIX" +fi + +cd "$WORKSPACE_DIR/forge" + +export WANDB_MODE=offline +export HF_HUB_OFFLINE=1 +export MONARCH_HOST_MESH_V1_REMOVE_ME_BEFORE_RELEASE=1 +export TORCHSTORE_RDMA_ENABLED=1 +export HF_HOME=/mnt/wsfuse/teamforge/hf + +# Execute the client training script with all passed arguments +exec python -X faulthandler .meta/mast/main.py "$@" diff --git a/.meta/mast/env_setup.sh b/.meta/mast/env_setup.sh index feef663b7..323e2febe 100755 --- a/.meta/mast/env_setup.sh +++ b/.meta/mast/env_setup.sh @@ -7,10 +7,9 @@ # LICENSE file in the root directory of this source tree. # setup_forge_env.sh - Setup conda environment and install forge with mounting -set -e # Exit on any error # Configuration -CONDA_ENV_NAME="forge:stable" +CONDA_ENV_NAME="forge:41468b33a03eaf2bf5b44517f418028a" # Colors for output RED='\033[0;31m' @@ -109,8 +108,6 @@ fi # Define paths FBSOURCE_PATH="/data/users/$USER/fbsource" CONDA_SCRIPT_PATH="$FBSOURCE_PATH/genai/xlformers/dev/xl_conda.sh" -FORGE_BASE_DIR="/data/users/$USER" -FORGE_REPO_DIR="$FORGE_BASE_DIR/forge" # Workspace URL for mounting WORKSPACE_URL="ws://ws.ai.pci0ai/genai_fair_llm" @@ -143,63 +140,12 @@ fi log_info "Conda environment activated successfully" -# Step 3: Create and navigate to forge base directory -log_info "Step 3: Setting up forge directory..." -if [ ! -d "$FORGE_BASE_DIR" ]; then - log_info "Creating forge base directory: $FORGE_BASE_DIR" - mkdir -p "$FORGE_BASE_DIR" -fi - -cd "$FORGE_BASE_DIR" -log_info "Changed to directory: $(pwd)" -# Step 4: Clone or update forge repository -log_info "Step 4: Setting up forge git repository..." -if [ -d "$FORGE_REPO_DIR" ]; then - log_warn "Forge repository already exists at: $FORGE_REPO_DIR" - cd "$FORGE_REPO_DIR" - - if [ -d ".git" ]; then - log_info "Updating existing repository..." - git fetch origin - if [ $? -eq 0 ]; then - log_info "Repository updated successfully" - else - log_warn "Failed to fetch updates, continuing with existing code" - fi - else - log_error "Directory exists but is not a git repository" - log_info "Removing directory and cloning fresh..." - cd "$FORGE_BASE_DIR" - rm -rf "$FORGE_REPO_DIR" - git clone git@github.com:meta-pytorch/forge.git - if [ $? -ne 0 ]; then - log_error "Failed to clone forge repository" - exit 1 - fi - cd "$FORGE_REPO_DIR" - fi -else - log_info "Cloning forge repository..." - git clone git@github.com:meta-pytorch/forge.git - if [ $? -ne 0 ]; then - log_error "Failed to clone forge repository" - log_error "Please ensure:" - log_error "1. You have SSH access to github.com" - log_error "2. Your SSH key is added to GitHub" - log_error "3. You have access to meta-pytorch/forge repository" - exit 1 - fi - cd "$FORGE_REPO_DIR" -fi - -log_info "Current directory: $(pwd)" - -# Step 5: Install torchtitan -log_info "Step 5: Installing torchtitan..." +# Step 3: Install torchtitan +log_info "Step 3: Installing torchtitan..." # Source versions.sh to get the pinned commit -VERSIONS_FILE="$FORGE_REPO_DIR/assets/versions.sh" +VERSIONS_FILE="assets/versions.sh" if [ -f "$VERSIONS_FILE" ]; then log_info "Sourcing version information from: $VERSIONS_FILE" source "$VERSIONS_FILE" @@ -225,8 +171,8 @@ else exit 1 fi -# Step 5.5: Apply monarch torch import hack -log_info "Step 5.5: Applying monarch torch import hack..." +# Step 3.5: Apply monarch torch import hack +log_info "Step 3.5: Applying monarch torch import hack..." MONARCH_INIT="$CONDA_PREFIX/lib/python3.10/site-packages/monarch/__init__.py" if [ -f "$MONARCH_INIT" ]; then @@ -259,8 +205,26 @@ else log_warn "Skipping monarch torch import hack (monarch may not be installed yet)" fi -# Step 6: Install forge package -log_info "Step 6: Installing forge package..." +# Step 4: Check for existing build directory and warn user +log_info "Step 4: Checking for existing build directory..." +if [ -d "build" ]; then + log_warn "Detected existing build/ directory at: $(pwd)/build" + log_warn "This directory may contain artifacts from a previous pip installation" + log_warn "that could interfere with the current installation." + log_warn "If you encounter issues, manually remove it with: rm -rf build" + echo "" + read -p "$(echo -e ${YELLOW}Do you want to continue anyway? [y/N]:${NC} )" -n 1 -r + echo "" + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + log_info "Installation cancelled by user" + log_info "You can manually remove the build/ directory with: rm -rf build" + exit 0 + fi + log_warn "Continuing with existing build/ directory. Things might go wrong!" +fi + +# Step 5: Install forge package +log_info "Step 5: Installing forge package..." pip install --no-deps --force-reinstall . if [ $? -ne 0 ]; then log_error "Failed to install forge package" @@ -298,7 +262,11 @@ pip list | grep -E "(forge|monarch)" || log_warn "No forge/monarch packages foun log_info "Environment setup complete! You can now run your scripts." log_info "Mounted workspace available at: /mnt/wsfuse" -# Step 6: Ask user to deactivate and activate conda env conda environment +log_info "Unsetting CUDA_HOME and overwriting the LD_LIBRARY_PATH" +unset CUDA_HOME +export LD_LIBRARY_PATH=${CONDA_PREFIX}/lib + +# Step 6: Ask user to test echo "" log_info "Installation completed successfully!" echo "" diff --git a/.meta/mast/hydrate_cache.py b/.meta/mast/hydrate_cache.py new file mode 100644 index 000000000..6289aee8a --- /dev/null +++ b/.meta/mast/hydrate_cache.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""This is convenience script meant for hydrating the HuggingFace cache. + +This is meant for downloading the model weights and tokenizer to the cache, i.e. for +OilFS. + +Example: + +python .meta/mast/hydrate_cache.py --model-id Qwen/Qwen3-32B + +""" +import argparse +import os +import sys + +from transformers import AutoModelForCausalLM, AutoTokenizer + + +def main(): + parser = argparse.ArgumentParser( + description="Hydrate HuggingFace cache for a specific model" + ) + parser.add_argument( + "--model-id", + type=str, + required=True, + help="HuggingFace model ID (e.g., Qwen/Qwen3-8B)", + ) + args = parser.parse_args() + + # Ensure HF_HOME is set + hf_home = os.environ.get("HF_HOME") + if not hf_home: + print( + "ERROR: HF_HOME environment variable must be set. " + "You will likely want to run export HF_HOME=/mnt/wsfuse/teamforge/hf." + ) + sys.exit(1) + + print(f"Using HF_HOME: {hf_home}") + print(f"Downloading {args.model_id}...") + + # This will pull tokenizer + config + all weight shards + tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(args.model_id, trust_remote_code=True) + + print("Download complete. Cache hydrated.") + + +if __name__ == "__main__": + main() diff --git a/.meta/mast/launch.sh b/.meta/mast/launch.sh index 46da56d12..2ece4e58e 100755 --- a/.meta/mast/launch.sh +++ b/.meta/mast/launch.sh @@ -34,6 +34,12 @@ fi CONFIG_FILE="$1" +# Generate a unique job name +USER=$(whoami) +RANDOM_SUFFIX=$(cat /dev/urandom | tr -dc 'a-z0-9' | fold -w 6 | head -n 1) +JOB_NAME="${USER}-forge-${RANDOM_SUFFIX}" +log_info "Generated job name: $JOB_NAME" + # Get the directory where this script is located SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" @@ -64,5 +70,10 @@ fi log_info "Successfully reinstalled forge package" # Launch the job +CHECKPOINT_FOLDER=/mnt/wsfuse/teamforge/forge_runs/$JOB_NAME log_info "Launching MAST job..." -PYTHONPATH=. python .meta/mast/main.py --config "$CONFIG_FILE" + +# Manually override the relevant checkpoint path(s) +# This unfortunately cannot be done in the YAML itself since this should be +# based on job name... +PYTHONPATH=. python .meta/mast/main.py --job-name $JOB_NAME --config $CONFIG_FILE trainer.checkpoint.folder=${CHECKPOINT_FOLDER} trainer.dcp_path=${CHECKPOINT_FOLDER} diff --git a/.meta/mast/main.py b/.meta/mast/main.py index cd5de0be9..249d5c5c3 100644 --- a/.meta/mast/main.py +++ b/.meta/mast/main.py @@ -4,14 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import argparse import asyncio -import getpass -import uuid +import sys from apps.grpo.main import main as grpo_main from forge.cli.config import parse from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY -from forge.controller.provisioner import init_provisioner +from forge.controller.provisioner import get_or_create_provisioner from forge.types import ( Launcher, @@ -20,32 +20,40 @@ ProvisionerConfig, ServiceConfig, ) +from forge.util.config import parse from omegaconf import DictConfig DEFAULT_CHECKPOINT_FOLDER_KEY = "checkpoint_folder" DEFAULT_CHECKPOINT_FOLDER = "/mnt/wsfuse/teamforge/forge_runs/" -async def main(cfg: DictConfig): - """Main module for launching mast jobs for GRPO training.""" +async def main(cfg: DictConfig, mode: str = "detached", extra_args: list = None): + """Main module for launching mast jobs for GRPO training. + + Args: + cfg: Configuration dictionary + mode: "detached" (default) launches MAST job with client in MAST, + "remote" runs training directly (used when client runs in MAST) + extra_args: Additional CLI arguments to pass through to the client + """ if cfg.get(LAUNCHER_KEY, Launcher.MAST.value) != Launcher.MAST.value: raise ValueError("Launcher must be MAST.") - if cfg.get(JOB_NAME_KEY, None) is not None: - # prepend user name to the job to avoid name collision - cfg[JOB_NAME_KEY] = f"{getpass.getuser()}-{cfg[JOB_NAME_KEY]}" - print(f"Overriding mast job name to {cfg[JOB_NAME_KEY]}") + # Job name should already be set from CLI args in __main__ section + # No need to modify it further here + if cfg.get(JOB_NAME_KEY, None) is None: + raise ValueError("Job name is required but not provided") if cfg.get(DEFAULT_CHECKPOINT_FOLDER_KEY, DEFAULT_CHECKPOINT_FOLDER) is not None: # append job_name and guid to CP folder path to avoid path collision if cfg[DEFAULT_CHECKPOINT_FOLDER_KEY] == DEFAULT_CHECKPOINT_FOLDER: - cfg[ - DEFAULT_CHECKPOINT_FOLDER_KEY - ] = f"{cfg[DEFAULT_CHECKPOINT_FOLDER_KEY]}{cfg[JOB_NAME_KEY]}-{uuid.uuid4().hex[:6]}" + cfg[DEFAULT_CHECKPOINT_FOLDER_KEY] = ( + f"{cfg[DEFAULT_CHECKPOINT_FOLDER_KEY]}{cfg[JOB_NAME_KEY]}-{uuid.uuid4().hex[:6]}" + ) print(f"Overriding checkpoint folder to {cfg[DEFAULT_CHECKPOINT_FOLDER_KEY]}") # init mast provisioner - await init_provisioner( + await get_or_create_provisioner( ProvisionerConfig( launcher_config=LauncherConfig( launcher=Launcher(cfg.get(LAUNCHER_KEY, Launcher.MAST.value)), @@ -55,13 +63,50 @@ async def main(cfg: DictConfig): ) ) ) - await grpo_main(cfg) + + if mode == "detached": + # In detached mode, just launch the MAST job with client role included + launcher = MastLauncher( + launcher_config, + detached=True, + extra_args=extra_args or [], + ) + await launcher.launch_mast_job() + print(f"MAST job {launcher.job_name} launched successfully with client role.") + print("The client is running inside MAST and will execute the training.") + else: + # In remote mode, we're already running inside MAST, so mount directory, init provisioner and run training + mount_mnt_directory("/mnt/wsfuse") + await init_provisioner(ProvisionerConfig(launcher_config=launcher_config)) + await grpo_main(cfg) if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", + type=str, + default="detached", + choices=["detached", "remote"], + help="Run mode: 'detached' for launching MAST job with client in MAST, 'remote' for running training directly", + ) + parser.add_argument( + "--job-name", + type=str, + default=None, + help="MAST job name (required - generated by launch.sh)", + ) + args, remaining = parser.parse_known_args() + + # Replace sys.argv with remaining args so @parse can work + sys.argv = [sys.argv[0]] + remaining @parse def _main(cfg): - asyncio.run(main(cfg)) + # Override job name from CLI + if args.job_name: + cfg[JOB_NAME_KEY] = args.job_name + print(f"Using job name: {args.job_name}") + asyncio.run(main(cfg, mode=args.mode, extra_args=remaining)) _main() # @parse grabs the cfg from CLI diff --git a/.meta/mast/qwen3_14b_mast.yaml b/.meta/mast/qwen3_14b_mast.yaml index f1f05825f..786f0103c 100644 --- a/.meta/mast/qwen3_14b_mast.yaml +++ b/.meta/mast/qwen3_14b_mast.yaml @@ -3,14 +3,13 @@ # Global configuration group_size: 8 -batch_size: 16 +local_batch_size: 16 # per-device batch size max_req_tokens: 512 max_res_tokens: 512 model: "Qwen/Qwen3-14B" off_by_n: 1 # Off by one by default launcher: mast job_name: forge-qwen3-14b -checkpoint_folder: /mnt/wsfuse/teamforge/forge_runs/ # Main loop configuration rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas @@ -26,7 +25,7 @@ metric_logging: # Dataset configuration dataset: - path: "openai/gsm8k" + path: /mnt/wsfuse/teamforge/hf/gsm8k revision: "main" data_split: "train" streaming: true @@ -35,7 +34,7 @@ dataset: # Policy configuration policy: engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs - model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-14B/snapshots/8268fe3026cb304910457689366670e803a6fd56 + model: /mnt/wsfuse/teamforge/hf/qwen3_14b tensor_parallel_size: 2 pipeline_parallel_size: 1 enforce_eager: false @@ -53,7 +52,7 @@ trainer: model: name: qwen3 flavor: 14B - hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-14B/snapshots/8268fe3026cb304910457689366670e803a6fd56 + hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_14b optimizer: name: AdamW lr: 1e-5 @@ -61,7 +60,7 @@ trainer: lr_scheduler: warmup_steps: 1 training: - local_batch_size: ${batch_size} + local_batch_size: ${local_batch_size} seq_len: 2048 max_norm: 1.0 steps: 1000000 @@ -79,8 +78,9 @@ trainer: disable_loss_parallel: true checkpoint: enable: true - initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-14B/snapshots/8268fe3026cb304910457689366670e803a6fd56 + initial_load_path: /mnt/wsfuse/teamforge/hf/qwen3_14b initial_load_in_hf: true + folder: ${checkpoint_folder} last_save_in_hf: true interval: 500 async_mode: "disabled" @@ -95,7 +95,7 @@ trainer: # Replay buffer configuration replay_buffer: - batch_size: ${batch_size} + batch_size: ${local_batch_size} max_policy_age: ${off_by_n} dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree @@ -104,7 +104,7 @@ ref_model: model: name: qwen3 flavor: 14B - hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-14B/snapshots/8268fe3026cb304910457689366670e803a6fd56 + hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_14b training: dtype: bfloat16 gc_freq: 1 @@ -119,7 +119,8 @@ ref_model: expert_parallel_degree: 1 checkpoint: enable: true - initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-14B/snapshots/8268fe3026cb304910457689366670e803a6fd56 + initial_load_path: /mnt/wsfuse/teamforge/hf/qwen3_14b + folder: "" initial_load_in_hf: true comm: # TODO: needs to be revisited. causing NCCL timeouts on inits when loading CP diff --git a/.meta/mast/qwen3_1_7b_mast.yaml b/.meta/mast/qwen3_1_7b_mast.yaml index 39aaf01ba..4065cf07a 100644 --- a/.meta/mast/qwen3_1_7b_mast.yaml +++ b/.meta/mast/qwen3_1_7b_mast.yaml @@ -3,14 +3,13 @@ # Global configuration group_size: 8 -batch_size: 16 +local_batch_size: 16 # per-device batch size max_req_tokens: 512 max_res_tokens: 512 model: "Qwen/Qwen3-1.7B" off_by_n: 1 # Off by one by default launcher: mast job_name: forge-qwen3-1_7b -checkpoint_folder: /mnt/wsfuse/teamforge/forge_runs/ # Main loop configuration rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas @@ -26,7 +25,7 @@ metric_logging: # Dataset configuration dataset: - path: "openai/gsm8k" + path: /mnt/wsfuse/teamforge/hf/gsm8k revision: "main" data_split: "train" streaming: true @@ -35,7 +34,7 @@ dataset: # Policy configuration policy: engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs - model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-1.7B/snapshots/0060bc56d46589041c1048efd1a397421b1142b5 + model: /mnt/wsfuse/teamforge/hf/qwen3_1.7b tensor_parallel_size: 1 pipeline_parallel_size: 1 enforce_eager: false @@ -53,7 +52,8 @@ trainer: model: name: qwen3 flavor: 1.7B - hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-1.7B/snapshots/0060bc56d46589041c1048efd1a397421b1142b5 + hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_1.7b + # hf_assets_path: hf://${model} optimizer: name: AdamW lr: 1e-5 @@ -61,7 +61,7 @@ trainer: lr_scheduler: warmup_steps: 1 training: - local_batch_size: ${batch_size} + local_batch_size: ${local_batch_size} seq_len: 2048 max_norm: 1.0 steps: 1000000 @@ -79,8 +79,9 @@ trainer: disable_loss_parallel: true checkpoint: enable: true - initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-1.7B/snapshots/0060bc56d46589041c1048efd1a397421b1142b5 + initial_load_path: /mnt/wsfuse/teamforge/hf/qwen3_1.7b initial_load_in_hf: true + folder: ${checkpoint_folder} last_save_in_hf: true interval: 500 async_mode: "disabled" @@ -95,7 +96,7 @@ trainer: # Replay buffer configuration replay_buffer: - batch_size: ${batch_size} + batch_size: ${local_batch_size} max_policy_age: ${off_by_n} dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree @@ -104,7 +105,8 @@ ref_model: model: name: qwen3 flavor: 1.7B - hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-1.7B/snapshots/0060bc56d46589041c1048efd1a397421b1142b5 + hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_1.7b + # hf_assets_path: hf://${model} training: dtype: bfloat16 gc_freq: 1 @@ -119,20 +121,21 @@ ref_model: expert_parallel_degree: 1 checkpoint: enable: true - initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-1.7B/snapshots/0060bc56d46589041c1048efd1a397421b1142b5 + initial_load_path: /mnt/wsfuse/teamforge/hf/qwen3_1.7b + folder: "" initial_load_in_hf: true # All resource allocations services: policy: procs: ${policy.engine_args.tensor_parallel_size} - num_replicas: 2 + num_replicas: 1 with_gpus: true mesh_name: policy hosts: 1 ref_model: procs: 1 - num_replicas: 2 + num_replicas: 1 with_gpus: true mesh_name: ref_model hosts: 1 diff --git a/.meta/mast/qwen3_32b_mast.yaml b/.meta/mast/qwen3_32b_mast.yaml index 2dc25509d..713c1f784 100644 --- a/.meta/mast/qwen3_32b_mast.yaml +++ b/.meta/mast/qwen3_32b_mast.yaml @@ -3,14 +3,13 @@ # Global configuration group_size: 8 -batch_size: 16 +local_batch_size: 16 # per-device batch size max_req_tokens: 512 max_res_tokens: 512 model: "Qwen/Qwen3-32B" off_by_n: 1 # Off by one by default launcher: mast job_name: forge-qwen3-32b -checkpoint_folder: /mnt/wsfuse/teamforge/forge_runs/ # Main loop configuration rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas @@ -26,7 +25,7 @@ metric_logging: # Dataset configuration dataset: - path: "openai/gsm8k" + path: /mnt/wsfuse/teamforge/hf/gsm8k revision: "main" data_split: "train" streaming: true @@ -35,7 +34,7 @@ dataset: # Policy configuration policy: engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs - model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-32B/snapshots/d47b0d4ae4b48fde975756bf360a63a9cca8d470 + model: /mnt/wsfuse/teamforge/hf/qwen3_32b tensor_parallel_size: 2 pipeline_parallel_size: 1 enforce_eager: false @@ -53,7 +52,7 @@ trainer: model: name: qwen3 flavor: 32B - hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-32B/snapshots/d47b0d4ae4b48fde975756bf360a63a9cca8d470 + hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_32b optimizer: name: AdamW lr: 1e-5 @@ -61,7 +60,7 @@ trainer: lr_scheduler: warmup_steps: 1 training: - local_batch_size: ${batch_size} + local_batch_size: ${local_batch_size} seq_len: 2048 max_norm: 1.0 steps: 1000000 @@ -79,8 +78,9 @@ trainer: disable_loss_parallel: true checkpoint: enable: true - initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-32B/snapshots/d47b0d4ae4b48fde975756bf360a63a9cca8d470 + initial_load_path: /mnt/wsfuse/teamforge/hf/qwen3_32b initial_load_in_hf: true + folder: ${checkpoint_folder} last_save_in_hf: true interval: 500 async_mode: "disabled" @@ -95,7 +95,7 @@ trainer: # Replay buffer configuration replay_buffer: - batch_size: ${batch_size} + batch_size: ${local_batch_size} max_policy_age: ${off_by_n} dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree @@ -104,7 +104,7 @@ ref_model: model: name: qwen3 flavor: 32B - hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-32B/snapshots/d47b0d4ae4b48fde975756bf360a63a9cca8d470 + hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_32b training: dtype: bfloat16 gc_freq: 1 @@ -119,7 +119,8 @@ ref_model: expert_parallel_degree: 1 checkpoint: enable: true - initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-32B/snapshots/d47b0d4ae4b48fde975756bf360a63a9cca8d470 + initial_load_path: /mnt/wsfuse/teamforge/hf/qwen3_32b + folder: "" initial_load_in_hf: true comm: # TODO: needs to be revisited. causing NCCL timeouts on inits when loading CP diff --git a/.meta/mast/qwen3_4b_mast.yaml b/.meta/mast/qwen3_4b_mast.yaml index 5e74f4b2a..e11e2a25a 100644 --- a/.meta/mast/qwen3_4b_mast.yaml +++ b/.meta/mast/qwen3_4b_mast.yaml @@ -3,14 +3,13 @@ # Global configuration group_size: 8 -batch_size: 16 +local_batch_size: 16 # per-device batch size max_req_tokens: 512 max_res_tokens: 512 model: "Qwen/Qwen3-4B" off_by_n: 1 # Off by one by default launcher: mast job_name: forge-qwen3-4b -checkpoint_folder: /mnt/wsfuse/teamforge/forge_runs/ # Main loop configuration rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas @@ -26,7 +25,7 @@ metric_logging: # Dataset configuration dataset: - path: "openai/gsm8k" + path: /mnt/wsfuse/teamforge/hf/gsm8k revision: "main" data_split: "train" streaming: true @@ -35,7 +34,7 @@ dataset: # Policy configuration policy: engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs - model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-4B-Base/snapshots/a81b894c2624d21c88a3ad737ce4f837424b7eed + model: /mnt/wsfuse/teamforge/hf/qwen3_4b tensor_parallel_size: 2 pipeline_parallel_size: 1 enforce_eager: false @@ -53,7 +52,8 @@ trainer: model: name: qwen3 flavor: 4B - hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-4B-Base/snapshots/a81b894c2624d21c88a3ad737ce4f837424b7eed + hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_4b + # hf_assets_path: hf://${model} optimizer: name: AdamW lr: 1e-5 @@ -61,7 +61,7 @@ trainer: lr_scheduler: warmup_steps: 1 training: - local_batch_size: ${batch_size} + local_batch_size: ${local_batch_size} seq_len: 2048 max_norm: 1.0 steps: 1000000 @@ -79,8 +79,9 @@ trainer: disable_loss_parallel: true checkpoint: enable: true - initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-4B-Base/snapshots/a81b894c2624d21c88a3ad737ce4f837424b7eed + initial_load_path: /mnt/wsfuse/teamforge/hf/qwen3_4b initial_load_in_hf: true + folder: ${checkpoint_folder} last_save_in_hf: true interval: 500 async_mode: "disabled" @@ -95,7 +96,7 @@ trainer: # Replay buffer configuration replay_buffer: - batch_size: ${batch_size} + batch_size: ${local_batch_size} max_policy_age: ${off_by_n} dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree @@ -104,7 +105,8 @@ ref_model: model: name: qwen3 flavor: 4B - hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-4B-Base/snapshots/a81b894c2624d21c88a3ad737ce4f837424b7eed + hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_4b + # hf_assets_path: hf://${model} training: dtype: bfloat16 gc_freq: 1 @@ -119,7 +121,8 @@ ref_model: expert_parallel_degree: 1 checkpoint: enable: true - initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-4B-Base/snapshots/a81b894c2624d21c88a3ad737ce4f837424b7eed + initial_load_path: /mnt/wsfuse/teamforge/hf/qwen3_4b + folder: "" initial_load_in_hf: true # All resource allocations @@ -144,7 +147,7 @@ services: actors: dataset: - procs: 8 + procs: 1 with_gpus: false mesh_name: dataset trainer: diff --git a/.meta/mast/qwen3_8b_mast.yaml b/.meta/mast/qwen3_8b_mast.yaml index 7f5b49af6..0405d767f 100644 --- a/.meta/mast/qwen3_8b_mast.yaml +++ b/.meta/mast/qwen3_8b_mast.yaml @@ -3,14 +3,13 @@ # Global configuration group_size: 8 -batch_size: 16 +local_batch_size: 16 # per-device batch size max_req_tokens: 512 max_res_tokens: 512 model: "Qwen/Qwen3-8B" off_by_n: 1 # Off by one by default launcher: mast job_name: forge-qwen3-8b -checkpoint_folder: /mnt/wsfuse/teamforge/forge_runs/ # Main loop configuration rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas @@ -26,7 +25,7 @@ metric_logging: # Dataset configuration dataset: - path: "openai/gsm8k" + path: /mnt/wsfuse/teamforge/hf/gsm8k revision: "main" data_split: "train" streaming: true @@ -35,7 +34,7 @@ dataset: # Policy configuration policy: engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs - model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-8B/snapshots/model + model: /mnt/wsfuse/teamforge/hf/qwen3_8b tensor_parallel_size: 2 pipeline_parallel_size: 1 enforce_eager: false @@ -53,7 +52,7 @@ trainer: model: name: qwen3 flavor: 8B - hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-8B/snapshots/model + hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_8b optimizer: name: AdamW lr: 1e-5 @@ -61,7 +60,7 @@ trainer: lr_scheduler: warmup_steps: 1 training: - local_batch_size: ${batch_size} + local_batch_size: ${local_batch_size} seq_len: 2048 max_norm: 1.0 steps: 1000000 @@ -79,8 +78,9 @@ trainer: disable_loss_parallel: true checkpoint: enable: true - initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-8B/snapshots/model + initial_load_path: /mnt/wsfuse/teamforge/hf/qwen3_8b initial_load_in_hf: true + folder: ${checkpoint_folder} last_save_in_hf: true interval: 500 async_mode: "disabled" @@ -95,7 +95,7 @@ trainer: # Replay buffer configuration replay_buffer: - batch_size: ${batch_size} + batch_size: ${local_batch_size} max_policy_age: ${off_by_n} dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree @@ -104,7 +104,7 @@ ref_model: model: name: qwen3 flavor: 8B - hf_assets_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-8B/snapshots/model + hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_8b training: dtype: bfloat16 gc_freq: 1 @@ -119,7 +119,8 @@ ref_model: expert_parallel_degree: 1 checkpoint: enable: true - initial_load_path: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-8B/snapshots/model + initial_load_path: /mnt/wsfuse/teamforge/hf/qwen3_8b + folder: "" initial_load_in_hf: true # All resource allocations diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1a9bad2ca..0069d0b17 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -47,3 +47,9 @@ repos: hooks: - id: pydoclint args: [--config=pyproject.toml] + +- repo: https://github.com/fastai/nbdev.git + rev: 2.4.5 + hooks: + - id: nbdev_clean + args: [--clear_all] diff --git a/README.md b/README.md index 2267856ad..5511bc611 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,18 @@ -# image Forge +# image torchforge #### A PyTorch-native agentic RL library that lets you focus on algorithms—not infra. [![Unit Tests](https://github.com/meta-pytorch/forge/actions/workflows/unit_test.yaml/badge.svg?branch=main)](https://github.com/meta-pytorch/forge/actions/workflows/unit_test.yaml?query=branch%3Amain) [![GPU Tests](https://github.com/meta-pytorch/forge/actions/workflows/gpu_test.yaml/badge.svg?branch=main)](https://github.com/meta-pytorch/forge/actions/workflows/gpu_test.yaml?query=branch%3Amain) ## Overview -The primary purpose of the Forge ecosystem is to delineate infra concerns from model concerns thereby making RL experimentation easier. Forge delivers this by providing clear RL abstractions and one scalable implementation of these abstractions. When you need fine-grained control over placement, fault handling/redirecting training loads during a run, or communication patterns, the primitives are there. When you don’t, you can focus purely on your RL algorithm. +The primary purpose of the torchforge ecosystem is to delineate infra concerns from model concerns thereby making RL experimentation easier. torchforge delivers this by providing clear RL abstractions and one scalable implementation of these abstractions. When you need fine-grained control over placement, fault handling/redirecting training loads during a run, or communication patterns, the primitives are there. When you don’t, you can focus purely on your RL algorithm. Key features: - Usability for rapid research (isolating the RL loop from infrastructure) - Hackability for power users (all parts of the RL loop can be easily modified without interacting with infrastructure) - Scalability (ability to shift between async and synchronous training and across thousands of GPUs) -> ⚠️ **Early Development Warning** Forge is currently in an experimental +> ⚠️ **Early Development Warning** torchforge is currently in an experimental > stage. You should expect bugs, incomplete features, and APIs that may change > in future versions. The project welcomes bugfixes, but to make sure things are > well coordinated you should discuss any significant change before starting the @@ -21,7 +21,7 @@ Key features: ## 📖 Documentation (Coming Soon) -View Forge's hosted documentation (coming soon) +View torchforge's hosted documentation (coming soon) ## Tutorials @@ -31,11 +31,11 @@ You can also find our notebook tutorials (coming soon) ### Basic -Forge requires the latest PyTorch nightly with [Monarch](https://github.com/meta-pytorch/monarch), [vLLM](https://github.com/vllm-project/vllm), and [torchtitan](https://github.com/pytorch/torchtitan). For convenience, +torchforge requires the latest PyTorch nightly with [Monarch](https://github.com/meta-pytorch/monarch), [vLLM](https://github.com/vllm-project/vllm), and [torchtitan](https://github.com/pytorch/torchtitan). For convenience, we have pre-packaged these dependencies as wheels in assets/wheels. (Note that the basic install script uses [DNF](https://docs.fedoraproject.org/en-US/quick-docs/dnf/), but could be easily extended to other Linux OS.) -Forge requires the Github CLI (gh) to download a compatible vLLM package. See [here](https://github.com/cli/cli#installation) for gh install instructions before continuting. Please login to gh with your Github account before continuing with `gh auth login`. You may use either https or ssh as the protocol for authentication. +torchforge requires the Github CLI (gh) to download a compatible vLLM package. See [here](https://github.com/cli/cli#installation) for gh install instructions before continuting. Please login to gh with your Github account before continuing with `gh auth login`. You may use either https or ssh as the protocol for authentication. ```bash conda create -n forge python=3.10 @@ -56,11 +56,6 @@ If you need to re-build the wheels for whatever reason, you can do so with: ./scripts/build_wheels.sh ``` -For your information, since the vLLM wheel is too large for GitHub, we uploaded it as a release in the `install.sh` script: -``` -$ gh release create v0.0.0 assets/wheels/vllm-*.whl --title "Forge Wheels v0.0.0" -``` - ## Quick Start To run SFT on a Llama3 8B model, run diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 1dbef0b76..bbd64f415 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -24,9 +24,8 @@ from forge.actors.reference_model import ReferenceModel from forge.actors.replay_buffer import ReplayBuffer from forge.actors.trainer import RLTrainer -from forge.cli.config import parse from forge.controller.actor import ForgeActor -from forge.controller.provisioner import init_provisioner, shutdown +from forge.controller.provisioner import get_or_create_provisioner, shutdown from forge.data.rewards import MathReward, ThinkingReward from forge.data_models.completion import Completion from forge.observability.metric_actors import get_or_create_metric_logger @@ -34,6 +33,7 @@ from forge.observability.perf_tracker import Tracer from forge.types import LauncherConfig, ProvisionerConfig +from forge.util.config import parse from forge.util.ops import compute_logprobs from monarch.actor import endpoint from omegaconf import DictConfig @@ -298,14 +298,14 @@ async def main(cfg: DictConfig): # ---- Global setups ---- # provisioner = None if cfg.get("provisioner", None) is not None: - provisioner = await init_provisioner( + provisioner = await get_or_create_provisioner( ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) ) else: - provisioner = await init_provisioner() + provisioner = await get_or_create_provisioner() metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) - mlogger = await get_or_create_metric_logger() + mlogger = await get_or_create_metric_logger(process_name="Controller") await mlogger.init_backends.call_one(metric_logging_cfg) # ---- Setup services ---- # @@ -346,7 +346,7 @@ async def main(cfg: DictConfig): # TODO: support multiple host meshes trainer_num_procs = cfg.actors.trainer["procs"] trainer_host_mesh_name = cfg.actors.trainer["mesh_name"] - trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name) + trainer_hosts = provisioner.get_host_mesh.call_one(trainer_host_mesh_name) await ts.initialize( mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}), strategy=ts.LocalRankStrategy(), diff --git a/apps/grpo/notebook.ipynb b/apps/grpo/notebook.ipynb new file mode 100644 index 000000000..8d9fbc75a --- /dev/null +++ b/apps/grpo/notebook.ipynb @@ -0,0 +1,669 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "46c66f45-4be3-4674-a870-3849c1048ddb", + "metadata": {}, + "source": [ + "# GRPO for Math (GSM8k)\n", + "\n", + "## Import modules" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "97d9ca00-92a8-4bd3-9b2b-ab8856f5acce", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) Meta Platforms, Inc. and affiliates.\n", + "# All rights reserved.\n", + "#\n", + "# This source code is licensed under the BSD-style license found in the\n", + "# LICENSE file in the root directory of this source tree.\n", + "\n", + "import asyncio\n", + "import time\n", + "import uuid\n", + "from dataclasses import dataclass\n", + "from typing import Any, Callable\n", + "\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import torchstore as ts\n", + "from datasets import load_dataset\n", + "from forge.actors._torchstore_utils import (\n", + " get_dcp_whole_state_dict_key,\n", + " get_param_prefix,\n", + ")\n", + "from forge.actors.generator import Generator as Policy\n", + "from forge.actors.reference_model import ReferenceModel\n", + "from forge.actors.replay_buffer import ReplayBuffer\n", + "from forge.actors.trainer import RLTrainer\n", + "from forge.cli.config import parse\n", + "from forge.controller.actor import ForgeActor\n", + "from forge.controller.provisioner import init_provisioner, shutdown\n", + "from forge.data.rewards import MathReward, ThinkingReward\n", + "from forge.observability.metric_actors import get_or_create_metric_logger\n", + "from forge.observability.metrics import record_metric, Reduce\n", + "from forge.observability.perf_tracker import Tracer\n", + "\n", + "from forge.types import LauncherConfig, ProvisionerConfig\n", + "from forge.util.ops import compute_logprobs\n", + "from monarch.actor import endpoint\n", + "from omegaconf import DictConfig\n", + "from vllm.transformers_utils.tokenizer import get_tokenizer\n", + "\n", + "import os\n", + "os.environ[\"MONARCH_HOSTMESH_V1\"] = \"1\"\n", + "os.environ[\"TORCHSTORE_RDMA_ENABLED\"] = \"1\"" + ] + }, + { + "cell_type": "markdown", + "id": "34d4319f-e6c9-4f4b-9b92-c572de08f0b2", + "metadata": {}, + "source": [ + "## Define Data Structures" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4a25e9d-e1dd-4ea7-a80c-383a2c04656a", + "metadata": {}, + "outputs": [], + "source": [ + "@dataclass\n", + "class Episode:\n", + " # TODO: add adtional layer for multi-turn\n", + " episode_id: str\n", + " request: str\n", + " policy_version: int\n", + " pad_id: int\n", + " request_len: int\n", + " response_len: int\n", + " target: Any | None = None\n", + " # processed data\n", + " response: str | None = None\n", + " request_tokens: list[int] | None = None\n", + " response_tokens: list[int] | None = None\n", + " ref_logprobs: torch.Tensor | None = None\n", + " reward: float | None = None\n", + " advantage: float | None = None\n", + "\n", + " @property\n", + " def request_tensor(self):\n", + " tensor = torch.tensor(self.request_tokens, dtype=torch.long)\n", + " if tensor.shape[0] < self.request_len: # left pad\n", + " diff = self.request_len - tensor.shape[0]\n", + " tensor = F.pad(tensor, (diff, 0), value=self.pad_id)\n", + " return tensor\n", + "\n", + " @property\n", + " def response_tensor(self):\n", + " tensor = torch.tensor(self.response_tokens, dtype=torch.long)\n", + " if tensor.shape[0] < self.response_len: # right pad\n", + " diff = self.response_len - tensor.shape[0]\n", + " tensor = F.pad(tensor, (0, diff), value=self.pad_id)\n", + " return tensor\n", + "\n", + "\n", + "@dataclass\n", + "class Group:\n", + " group_id: str\n", + " episodes: list[Episode]\n", + "\n", + " @classmethod\n", + " def new_group(\n", + " cls,\n", + " group_id: int,\n", + " group_size: int,\n", + " request: str,\n", + " policy_version: int,\n", + " pad_id: int,\n", + " request_len: int,\n", + " response_len: int,\n", + " target: Any = None,\n", + " ):\n", + " episodes = []\n", + " for _ in range(group_size):\n", + " episodes.append(\n", + " Episode(\n", + " episode_id=str(uuid.uuid4()),\n", + " request=request,\n", + " policy_version=policy_version,\n", + " pad_id=pad_id,\n", + " request_len=request_len,\n", + " response_len=response_len,\n", + " target=target,\n", + " )\n", + " )\n", + " return cls(str(group_id), episodes)\n", + "\n", + "\n", + "def collate(batches: list[list[Episode]]):\n", + " inputs = []\n", + " targets = []\n", + " for batch in batches:\n", + " request = [e.request_tensor for e in batch]\n", + " request = torch.stack(request) # [b x s]\n", + "\n", + " response = [e.response_tensor for e in batch]\n", + " response = torch.stack(response) # [b x s]\n", + "\n", + " ref_logprobs = [e.ref_logprobs for e in batch]\n", + " ref_logprobs = torch.stack(ref_logprobs).squeeze() # [b x s]\n", + "\n", + " advantages = [e.advantage for e in batch]\n", + " advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1]\n", + "\n", + " pad_id = batch[0].pad_id\n", + " mask = response != pad_id\n", + "\n", + " input = {\"tokens\": torch.cat([request, response], dim=1)}\n", + " target = {\n", + " \"response\": response,\n", + " \"ref_logprobs\": ref_logprobs,\n", + " \"advantages\": advantages,\n", + " \"padding_mask\": mask,\n", + " }\n", + " inputs.append(input)\n", + " targets.append(target)\n", + " return inputs, targets\n", + "\n", + "@dataclass\n", + "class DatasetActor(ForgeActor):\n", + " \"\"\"Actor wrapper for HuggingFace dataset to provide async interface.\"\"\"\n", + "\n", + " path: str = \"openai/gsm8k\"\n", + " revision: str = \"main\"\n", + " data_split: str = \"train\"\n", + " streaming: bool = True\n", + " model: str = \"Qwen/Qwen3-1.7B\"\n", + "\n", + " @endpoint\n", + " def setup(self):\n", + " self._tokenizer = get_tokenizer(self.model)\n", + "\n", + " def gsm8k_transform(sample):\n", + " system_prompt = \"\"\"\n", + " Put all your scratchpad work between and tags.\n", + " Your final answer should be between and tags otherwise it will not be scored.\n", + " \"\"\"\n", + " request: str = sample[\"question\"]\n", + " as_chat = [\n", + " {\"role\": \"system\", \"content\": system_prompt},\n", + " {\"role\": \"user\", \"content\": request},\n", + " ]\n", + " formatted_request = self._tokenizer.apply_chat_template(\n", + " as_chat,\n", + " tokenize=False,\n", + " add_generation_prompt=True,\n", + " )\n", + " target: str = sample[\"answer\"]\n", + " formatted_target = target.split(\"#### \")[1]\n", + " return {\"request\": formatted_request, \"target\": formatted_target}\n", + "\n", + " ds = load_dataset(\n", + " self.path, self.revision, split=self.data_split, streaming=self.streaming\n", + " )\n", + " ds = ds.map(gsm8k_transform)\n", + " ds = ds.shuffle()\n", + " self._iterator = iter(ds)\n", + "\n", + " @endpoint\n", + " async def sample(self) -> dict[str, str] | None:\n", + " try:\n", + " sample = next(self._iterator)\n", + "\n", + " # Record dataset metrics\n", + " record_metric(\"dataset/sample/count_samples_generated\", 1, Reduce.SUM)\n", + " record_metric(\n", + " \"dataset/sample/avg_sample_len\",\n", + " len(sample[\"request\"]),\n", + " Reduce.MEAN,\n", + " )\n", + "\n", + " return sample\n", + " except StopIteration:\n", + " return None\n", + "\n", + " @endpoint\n", + " async def pad_token(self):\n", + " return self._tokenizer.pad_token_id" + ] + }, + { + "cell_type": "markdown", + "id": "901b3d1d-7eba-4464-b881-48c11ff6e0ef", + "metadata": {}, + "source": [ + "## Define loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "934aca32-0953-4945-9f99-e7b34804443b", + "metadata": {}, + "outputs": [], + "source": [ + "def simple_grpo_loss(\n", + " logits: torch.Tensor,\n", + " response: torch.Tensor,\n", + " ref_logprobs: torch.Tensor,\n", + " advantages: torch.Tensor,\n", + " padding_mask: torch.Tensor,\n", + " beta: float = 0.1,\n", + ") -> torch.Tensor:\n", + " \"\"\"\n", + " Example GRPO Loss Function for RLTrainer\n", + " \"\"\"\n", + " logprobs: torch.Tensor = compute_logprobs(logits, response)\n", + "\n", + " # Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`\n", + " kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1\n", + " per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages\n", + " per_token_loss = -(per_token_policy_loss - beta * kl)\n", + " loss = (\n", + " ((per_token_loss * padding_mask).sum(dim=1))\n", + " / (padding_mask.sum(dim=1).clamp(min=1.0))\n", + " ).mean()\n", + " return loss" + ] + }, + { + "cell_type": "markdown", + "id": "d4f8bbe3-b7ac-4905-b197-f10990f9a104", + "metadata": {}, + "source": [ + "## Define Reward" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "163e98bf-e0f5-4ec3-9690-9839e687f9b3", + "metadata": {}, + "outputs": [], + "source": [ + "@dataclass\n", + "class RewardActor(ForgeActor):\n", + " \"\"\"Reward actor that uses a list of scoring functions.\"\"\"\n", + "\n", + " reward_functions: list[Callable]\n", + "\n", + " @endpoint\n", + " async def evaluate_response(self, prompt: str, response: str, target: str) -> float:\n", + " total_rewards = 0.0\n", + " for reward_fn in self.reward_functions:\n", + " reward = reward_fn(prompt, response, target)\n", + " total_rewards += reward\n", + "\n", + " # Get a name for the reward function (works for classes, functions, lambdas)\n", + " reward_fn_name = getattr(\n", + " reward_fn, \"__name__\", reward_fn.__class__.__name__\n", + " )\n", + " # per function reward\n", + " record_metric(\n", + " f\"reward/evaluate_response/sum_{reward_fn_name}_reward\",\n", + " reward,\n", + " Reduce.SUM,\n", + " )\n", + " record_metric(\n", + " f\"reward/evaluate_response/avg_{reward_fn_name}_reward\",\n", + " reward,\n", + " Reduce.MEAN,\n", + " )\n", + " record_metric(\n", + " f\"reward/evaluate_response/std_{reward_fn_name}_reward\",\n", + " reward,\n", + " Reduce.STD,\n", + " )\n", + "\n", + " # avg total reward\n", + " record_metric(\n", + " \"reward/evaluate_response/avg_total_reward\",\n", + " reward,\n", + " Reduce.MEAN,\n", + " )\n", + "\n", + " # count fn calls\n", + " record_metric(\n", + " f\"reward/evaluate_response/count_{reward_fn_name}_calls\",\n", + " 1,\n", + " Reduce.SUM,\n", + " )\n", + "\n", + " avg_reward = total_rewards / len(self.reward_functions)\n", + " return avg_reward\n", + "\n", + "\n", + "@dataclass\n", + "class ComputeAdvantages(ForgeActor):\n", + " \"\"\"Compute advantages for GRPO using reward signals.\"\"\"\n", + "\n", + " @endpoint\n", + " async def compute(self, group: Group) -> list[float]:\n", + " # TODO: add batch processing\n", + " rewards = torch.tensor([[e.reward for e in group.episodes]])\n", + " mean = rewards.mean(1, keepdim=True)\n", + " std = rewards.std(1, keepdim=True)\n", + " advantages = (rewards - mean) / (std + 1e-4)\n", + " return advantages.squeeze(0).tolist()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88523484-b414-41db-bd3f-0d8dbf881a85", + "metadata": {}, + "outputs": [], + "source": [ + "async def drop_weights(version: int):\n", + " print(f\"Dropping weights @ version {version}\")\n", + " start_time = time.perf_counter()\n", + " prefix = get_param_prefix(version)\n", + " matching_keys = await ts.keys(prefix)\n", + " # TODO: once we have something like `get_meta()` in torchstore, we can just\n", + " # query the type of the object instead of relying on keys.\n", + " dcp_key = get_dcp_whole_state_dict_key(version)\n", + " if dcp_key in matching_keys:\n", + " dcp_handle = await ts.get(dcp_key)\n", + " dcp_handle.drop()\n", + " for key in matching_keys:\n", + " await ts.delete(key)\n", + " elapsed = time.perf_counter() - start_time\n", + " print(f\"Dropped weights @ version {version}, took {elapsed:.2f} seconds\")" + ] + }, + { + "cell_type": "markdown", + "id": "95d4fef3-180b-4b7e-8871-ecbe113cde72", + "metadata": {}, + "source": [ + "## Setup Services" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c811974-cd6b-40ed-a179-4511a7a6c489", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "from omegaconf import OmegaConf\n", + "from forge.cli.config import resolve_hf_hub_paths\n", + "\n", + "cfg = OmegaConf.load('apps/grpo/qwen3_1_7b.yaml')\n", + "cfg = resolve_hf_hub_paths(cfg)\n", + "OmegaConf.resolve(cfg)\n", + "\n", + "group_size = cfg.group_size # 8\n", + "max_req_tokens = cfg.max_req_tokens # 512\n", + "max_res_tokens = cfg.max_res_tokens # 512\n", + "\n", + "metric_logging_cfg = cfg.get(\"metric_logging\", {\"console\": {\"log_per_rank\": False}})\n", + "mlogger = await get_or_create_metric_logger()\n", + "await mlogger.init_backends.call_one(metric_logging_cfg)\n", + "await ts.initialize(strategy=ts.ControllerStorageVolumes())\n", + "\n", + "dataloader, policy, trainer, replay_buffer, compute_advantages, ref_model, reward_actor = await asyncio.gather(\n", + " DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),\n", + " Policy.options(**cfg.services.policy).as_service(**cfg.policy),\n", + " RLTrainer.options(**cfg.actors.trainer).as_actor(\n", + " **cfg.trainer, loss=simple_grpo_loss\n", + " ),\n", + " ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(\n", + " **cfg.replay_buffer, collate=collate\n", + " ),\n", + " ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),\n", + " ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),\n", + " RewardActor.options(**cfg.services.reward_actor).as_service(\n", + " reward_functions=[MathReward(), ThinkingReward()]\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "3f2a305f-b1e2-4eac-803c-71bf3225fed7", + "metadata": {}, + "source": [ + "## Rollout Loop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1c676fb-2cd6-4c2c-87d4-e1b8cd0b87af", + "metadata": {}, + "outputs": [], + "source": [ + "async def continuous_rollouts():\n", + " rollout_count = 0\n", + " pad_id = await dataloader.pad_token.call_one()\n", + " while True:\n", + " t = Tracer(\"main_perf/continuous_rollouts\")\n", + " t.start()\n", + " sample = await dataloader.sample.call_one()\n", + " if sample is None:\n", + " print(\"Dataloader is empty, exiting continuous rollout\")\n", + " return\n", + "\n", + " t.step(\"data_loading\")\n", + "\n", + " prompt, target = sample[\"request\"], sample[\"target\"]\n", + " responses = await policy.generate.route(prompt)\n", + " # TODO: this shall be part of the responses metadata instead of a separate call\n", + " version = await policy.get_version.route()\n", + "\n", + " t.step(\"policy_generation\")\n", + "\n", + " assert (\n", + " len(responses) > 0\n", + " ), \"Sanity check: Responses should NEVER return empty\"\n", + " assert (\n", + " version := responses[0].generator_version\n", + " ) is not None, \"Response must indicate a version\"\n", + " group = Group.new_group(\n", + " group_id=rollout_count,\n", + " group_size=group_size,\n", + " request=prompt,\n", + " policy_version=version,\n", + " pad_id=pad_id,\n", + " request_len=max_req_tokens,\n", + " response_len=max_res_tokens,\n", + " target=target,\n", + " )\n", + "\n", + " input_ids = torch.ones(\n", + " (group_size, max_req_tokens + max_res_tokens),\n", + " dtype=torch.long,\n", + " device=\"cuda\",\n", + " )\n", + " # Populate episode info and calculate rewards\n", + " for i, (episode, response) in enumerate(zip(group.episodes, responses)):\n", + " episode.request_tokens = response.prompt_ids\n", + " episode.response_tokens = response.token_ids\n", + " episode.response = response.text\n", + " input_ids[i, :max_req_tokens] = episode.request_tensor\n", + " input_ids[i, max_req_tokens:] = episode.response_tensor\n", + " episode.reward = await reward_actor.evaluate_response.route(\n", + " prompt=prompt, response=response.text, target=target\n", + " )\n", + "\n", + " t.step(\"reward_evaluation\")\n", + "\n", + " ref_logprobs = await ref_model.forward.route(\n", + " input_ids, max_req_tokens, return_logprobs=True\n", + " )\n", + " t.step(\"reference_model_calculate_logprobs\")\n", + "\n", + " for i, episode in enumerate(group.episodes):\n", + " episode.ref_logprobs = ref_logprobs[i]\n", + " del ref_logprobs, input_ids\n", + " t.step(\"compute_logprobs\")\n", + "\n", + " # Calculate advantages and add to replay buffer\n", + " advantages = await compute_advantages.compute.call_one(group)\n", + " for episode, advantage in zip(group.episodes, advantages):\n", + " episode.advantage = advantage\n", + " await replay_buffer.add.call_one(episode)\n", + "\n", + " # Log metrics\n", + " rollout_count += 1\n", + " record_metric(\n", + " \"main/continuous_rollouts/count_rollout_iterations\", 1, Reduce.SUM\n", + " )\n", + " t.stop()" + ] + }, + { + "cell_type": "markdown", + "id": "57c316dc-11b5-48ea-8b03-e1bb9d9d1f2b", + "metadata": {}, + "source": [ + "## Training Loop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "916a0e79-aded-4ee3-b1a8-db0e772996c9", + "metadata": {}, + "outputs": [], + "source": [ + "async def continuous_training():\n", + " training_step = 0\n", + " restart_tracer = True # Flag to control when to restart tracer\n", + " while True:\n", + " # Restart tracer when needed (initial start or after completing a training step)\n", + " # Otherwise, we cannot measure time waiting for buffer\n", + " if restart_tracer:\n", + " t = Tracer(\"main_perf/continuous_training\")\n", + " t.start()\n", + " restart_tracer = False\n", + "\n", + " batch = await replay_buffer.sample.call_one(\n", + " curr_policy_version=training_step\n", + " )\n", + " if batch is None:\n", + " await asyncio.sleep(0.1)\n", + " else:\n", + " t.step(\"waiting_for_buffer\")\n", + "\n", + " inputs, targets = batch\n", + " await trainer.train_step.call(inputs, targets)\n", + " training_step += 1\n", + " t.step(\"train_step\")\n", + "\n", + " await trainer.push_weights.call(training_step)\n", + " t.step(\"push_weights\")\n", + "\n", + " await policy.update_weights.fanout(training_step)\n", + " update_task = asyncio.create_task(policy.update_weights.fanout(training_step))\n", + " t.step(\"update_weights\")\n", + "\n", + " if training_step >= 2:\n", + " await drop_weights(training_step - 1)\n", + " t.step(\"drop_weights\")\n", + "\n", + " t.stop()\n", + " restart_tracer = True\n", + "\n", + " # Flush metrics every training step to WandB\n", + " await mlogger.flush.call_one(training_step)" + ] + }, + { + "cell_type": "markdown", + "id": "4542863b-59c5-40bc-896c-6d8d44ada00f", + "metadata": {}, + "source": [ + "## Run" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58194c13-b75e-405d-ab11-18cbe1874d92", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "num_rollout_threads = 1\n", + "num_training_threads = 1\n", + "\n", + "rollout_tasks = [\n", + " asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads)\n", + "]\n", + "training_task = asyncio.create_task(continuous_training())\n", + "\n", + "try:\n", + " await asyncio.gather(*rollout_tasks, training_task)\n", + "except KeyboardInterrupt:\n", + " print(\"Training interrupted by user\")\n", + " for rollout_task in rollout_tasks:\n", + " rollout_task.cancel()\n", + " training_task.cancel()" + ] + }, + { + "cell_type": "markdown", + "id": "b4603b80-1f25-49a1-920e-d24f38dfc687", + "metadata": {}, + "source": [ + "## Shutdown" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d74e781-c253-4bd0-929f-bd4ad516ba81", + "metadata": {}, + "outputs": [], + "source": [ + "await mlogger.shutdown.call_one()\n", + "await asyncio.sleep(2)\n", + "\n", + "await asyncio.gather(\n", + " DatasetActor.shutdown(dataloader),\n", + " policy.shutdown(),\n", + " RLTrainer.shutdown(trainer),\n", + " ReplayBuffer.shutdown(replay_buffer),\n", + " ComputeAdvantages.shutdown(compute_advantages),\n", + " ref_model.shutdown(),\n", + " reward_actor.shutdown(),\n", + ")\n", + "await shutdown()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "forge", + "language": "python", + "name": "forge" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 14e4871cf..cc67952fd 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -74,8 +74,9 @@ trainer: disable_loss_parallel: true checkpoint: enable: true - initial_load_path: hf://${model} - initial_load_in_hf: true + folder: ./checkpoint # The folder to save checkpoints to. + initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists. + initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo last_save_in_hf: true interval: 500 async_mode: "disabled" @@ -96,6 +97,7 @@ ref_model: flavor: 1.7B hf_assets_path: hf://${model} training: + seq_len: ${trainer.training.seq_len} dtype: bfloat16 gc_freq: 1 compile: diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml index e7a0cf509..27c71b3db 100644 --- a/apps/grpo/qwen3_32b.yaml +++ b/apps/grpo/qwen3_32b.yaml @@ -77,8 +77,9 @@ trainer: disable_loss_parallel: true checkpoint: enable: true - initial_load_path: hf://${model} - initial_load_in_hf: true + folder: ./checkpoint # The folder to save checkpoints to. + initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists. + initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo last_save_in_hf: true interval: 500 async_mode: "disabled" @@ -99,6 +100,7 @@ ref_model: flavor: 32B hf_assets_path: hf://${model} training: + seq_len: ${trainer.training.seq_len} dtype: bfloat16 gc_freq: 1 compile: diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml index 534e5b92a..e19b751d3 100644 --- a/apps/grpo/qwen3_8b.yaml +++ b/apps/grpo/qwen3_8b.yaml @@ -70,8 +70,9 @@ trainer: disable_loss_parallel: true checkpoint: enable: true - initial_load_path: hf://${model} - initial_load_in_hf: true + folder: ./checkpoint # The folder to save checkpoints to. + initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists. + initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo last_save_in_hf: true interval: 500 async_mode: "disabled" @@ -95,6 +96,7 @@ ref_model: flavor: 8B hf_assets_path: hf://${model} training: + seq_len: ${trainer.training.seq_len} dtype: bfloat16 gc_freq: 1 compile: diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml index 43a690c1e..44e4485e4 100644 --- a/apps/sft/llama3_8b.yaml +++ b/apps/sft/llama3_8b.yaml @@ -45,8 +45,9 @@ parallelism: checkpoint: enable: true - initial_load_path: hf://${model_name} - initial_load_in_hf: true + folder: ./checkpoint # The folder to save checkpoints to. + initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists. + initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo last_save_in_hf: true interval: 500 async_mode: "disabled" diff --git a/apps/sft/main.py b/apps/sft/main.py index 27a8036d4..aa484608e 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -22,12 +22,12 @@ import torch import torchtitan.experiments.forge.train_spec as forge_train_spec -from forge.cli.config import parse from forge.controller import ForgeActor from forge.data.collate import collate_packed from forge.data.datasets.packed import PackedDataset, TextPacker from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset from forge.data.tokenizer import HuggingFaceModelTokenizer +from forge.util.config import parse from monarch.actor import current_rank, current_size, endpoint from omegaconf import DictConfig, OmegaConf diff --git a/apps/sft/qwen3_8b.yaml b/apps/sft/qwen3_8b.yaml index 2ab88bbd3..1c0d5bc8b 100644 --- a/apps/sft/qwen3_8b.yaml +++ b/apps/sft/qwen3_8b.yaml @@ -44,8 +44,9 @@ parallelism: checkpoint: enable: true - initial_load_path: hf://${model_name} - initial_load_in_hf: true + folder: ./checkpoint # The folder to save checkpoints to. + initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists. + initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo last_save_in_hf: true interval: 500 async_mode: "disabled" diff --git a/assets/versions.sh b/assets/versions.sh index 7c188b0d5..6dfee6761 100644 --- a/assets/versions.sh +++ b/assets/versions.sh @@ -14,6 +14,6 @@ PYTORCH_VERSION="2.9.0.dev20250905" VLLM_BRANCH="v0.10.0" # Commit hashes -MONARCH_COMMIT="195503223b5c2896846171f60ac99dc6868f8f2c" +MONARCH_COMMIT="main" TORCHTITAN_COMMIT="0cfbd0b3c2d827af629a107a77a9e47229c31663" TORCHSTORE_COMMIT="662299faf4fd50ee30bd9aa3f4ce8c0e2db1d310" diff --git a/docs/source/tutorial_sources/zero-to-forge/1_RL_and_Forge_Fundamentals.md b/docs/source/tutorial_sources/zero-to-forge/1_RL_and_Forge_Fundamentals.md index fd7c0cf6b..ae74df101 100644 --- a/docs/source/tutorial_sources/zero-to-forge/1_RL_and_Forge_Fundamentals.md +++ b/docs/source/tutorial_sources/zero-to-forge/1_RL_and_Forge_Fundamentals.md @@ -88,7 +88,7 @@ graph LR subgraph Services["TorchForge Services (Real Classes)"] direction TB S1["DatasetActor"] - S2["Policy"] + S2["Generator"] S3["RewardActor"] S4["ReferenceModel"] S5["ReplayBuffer"] @@ -290,7 +290,7 @@ TorchForge handles behind the scenes: ### Independent Scaling ```python -from forge.actors.policy import Policy +from forge.actors.generator import Generator as Policy from forge.actors.replay_buffer import ReplayBuffer from forge.actors.reference_model import ReferenceModel from forge.actors.trainer import RLTrainer diff --git a/docs/source/tutorial_sources/zero-to-forge/2_Forge_Internals.md b/docs/source/tutorial_sources/zero-to-forge/2_Forge_Internals.md index 9c8f89bc2..335c5fc5a 100644 --- a/docs/source/tutorial_sources/zero-to-forge/2_Forge_Internals.md +++ b/docs/source/tutorial_sources/zero-to-forge/2_Forge_Internals.md @@ -73,7 +73,7 @@ The service creation automatically handles: - Message routing and serialization ```python -from forge.actors.policy import Policy +from forge.actors.generator import Generator as Policy model = "Qwen/Qwen3-1.7B" @@ -560,7 +560,7 @@ Now let's see how services coordinate in a real training loop: import asyncio import torch -from forge.actors.policy import Policy +from forge.actors.generator import Generator as Policy from forge.actors.reference_model import ReferenceModel from forge.actors.replay_buffer import ReplayBuffer from forge.actors.trainer import RLTrainer diff --git a/docs/source/tutorial_sources/zero-to-forge/3_Monarch_101.md b/docs/source/tutorial_sources/zero-to-forge/3_Monarch_101.md index a5a28c7a6..8a53566c0 100644 --- a/docs/source/tutorial_sources/zero-to-forge/3_Monarch_101.md +++ b/docs/source/tutorial_sources/zero-to-forge/3_Monarch_101.md @@ -18,15 +18,15 @@ graph TD end subgraph MonarchLayer["3. Monarch Actor Layer"] - ActorMesh["ActorMesh PolicyActor: 4 instances, Different GPUs, Message passing"] + ActorMesh["ActorMesh Policy Actor: 4 instances, Different GPUs, Message passing"] ProcMesh["ProcMesh: 4 processes, GPU topology 0,1,2,3, Network interconnect"] end subgraph Hardware["4. Physical Hardware"] - GPU0["GPU 0: PolicyActor #1, vLLM Engine, Model Weights"] - GPU1["GPU 1: PolicyActor #2, vLLM Engine, Model Weights"] - GPU2["GPU 2: PolicyActor #3, vLLM Engine, Model Weights"] - GPU3["GPU 3: PolicyActor #4, vLLM Engine, Model Weights"] + GPU0["GPU 0: Policy Actor #1, vLLM Engine, Model Weights"] + GPU1["GPU 1: Policy Actor #2, vLLM Engine, Model Weights"] + GPU2["GPU 2: Policy Actor #3, vLLM Engine, Model Weights"] + GPU3["GPU 3: Policy Actor #4, vLLM Engine, Model Weights"] end Call --> ServiceInterface @@ -154,7 +154,7 @@ await procs.stop() **ActorMesh** is created when you spawn actors across a ProcMesh. Key points: -- **One actor instance per process**: `mesh.spawn("policy", PolicyActor)` creates one PolicyActor in each process +- **One actor instance per process**: `mesh.spawn("policy", Policy)` creates one Policy Actor in each process - **Same constructor arguments**: All instances get the same initialization parameters - **Independent state**: Each actor instance maintains its own state and memory - **Message routing**: You can send messages to one actor or all actors using different methods @@ -162,9 +162,9 @@ await procs.stop() ```python # Simple example: procs = spawn_procs(per_host={"gpus": 4}) # 4 processes -policy_actors = procs.spawn("policy", PolicyActor, model="Qwen/Qwen3-7B") +policy_actors = procs.spawn("policy", Policy, model="Qwen/Qwen3-7B") -# Now you have 4 PolicyActor instances, one per GPU +# Now you have 4 Policy Actor instances, one per GPU # All initialized with the same model parameter ``` @@ -177,29 +177,29 @@ Now the key insight: **TorchForge services are ServiceActors that manage ActorMe ```mermaid graph TD subgraph ServiceCreation["Service Creation Process"] - Call["await PolicyActor.options(num_replicas=4, procs=1).as_service(model='Qwen')"] + Call["await Policy.options(num_replicas=4, procs=1).as_service(model='Qwen')"] ServiceActor["ServiceActor: Manages 4 replicas, Health checks, Routes calls"] subgraph Replicas["4 Independent Replicas"] subgraph R0["Replica 0"] PM0["ProcMesh: 1 process, GPU 0"] - AM0["ActorMesh
1 PolicyActor"] + AM0["ActorMesh
1 Policy Actor"] end subgraph R1["Replica 1"] PM1["ProcMesh: 1 process, GPU 1"] - AM1["ActorMesh
1 PolicyActor"] + AM1["ActorMesh
1 Policy Actor"] end subgraph R2["Replica 2"] PM2["ProcMesh: 1 process, GPU 2"] - AM2["ActorMesh
1 PolicyActor"] + AM2["ActorMesh
1 Policy Actor"] end subgraph R3["Replica 3"] PM3["ProcMesh: 1 process, GPU 3"] - AM3["ActorMesh
1 PolicyActor"] + AM3["ActorMesh
1 Policy Actor"] end end @@ -232,9 +232,9 @@ graph TD ServiceActor["ServiceActor: Selects healthy replica, Load balancing, Failure handling"] - SelectedReplica["Selected Replica #2: ProcMesh 1 process, ActorMesh 1 PolicyActor"] + SelectedReplica["Selected Replica #2: ProcMesh 1 process, ActorMesh 1 Policy Actor"] - PolicyActor["PolicyActor Instance: Loads model, Runs vLLM inference"] + PolicyActor["Policy Actor Instance: Loads model, Runs vLLM inference"] GPU["GPU 2: vLLM engine, Model weights, KV cache, CUDA kernels"] diff --git a/pyproject.toml b/pyproject.toml index 886ed672c..1d0ff0cf1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dev = [ "tomli>=1.1.0", "anyio", "pytest-asyncio", + "multiprocess", ] oss = [ "torch", @@ -49,10 +50,6 @@ oss = [ "torchstore", ] -[project.scripts] -forge = "forge.cli.forge:main" - - # ---- Explicit project build information ---- # [build-system] requires = ["setuptools>=61.0"] diff --git a/src/forge/actors/_torchstore_utils.py b/src/forge/actors/_torchstore_utils.py index bc0d55c3b..2d14f7f30 100644 --- a/src/forge/actors/_torchstore_utils.py +++ b/src/forge/actors/_torchstore_utils.py @@ -10,6 +10,7 @@ import torch import torch.distributed.checkpoint as dcp from torch.distributed.checkpoint.metadata import Metadata as DcpMeta +from torchstore.transport.buffers import rdma_available logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -69,3 +70,8 @@ def extract_param_name(key: str) -> str: def get_dcp_whole_state_dict_key(policy_version: int) -> str: return f"{get_param_prefix(policy_version)}{KEY_DELIM}{DCP_WHOLE_STATE_TAG}" + + +def rdma_enabled() -> bool: + """Return if TorchStore thinks we're using RDMA""" + return rdma_available() diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 0dc385cc0..a536d46ae 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -13,10 +13,34 @@ from collections.abc import Mapping from copy import copy from dataclasses import dataclass, field +from typing import Optional import torch import torchstore as ts -from monarch.actor import current_rank, endpoint, ProcMesh + +from forge.actors._torchstore_utils import ( + extract_param_name, + get_dcp_whole_state_dict_key, + get_param_key, + get_param_prefix, + load_tensor_from_dcp, + rdma_available, +) + +from forge.controller import ( + ForgeActor, + get_proc_mesh, + host_mesh_from_proc, + stop_proc_mesh, +) +from forge.data_models.completion import Completion +from forge.data_models.prompt import to_prompt +from forge.observability.metrics import record_metric, Reduce +from forge.observability.perf_tracker import Tracer +from forge.types import ProcessConfig +from forge.util._shared_tensor import SharedTensor, SharedTensorHandle +from monarch.actor import current_rank, endpoint, ProcMesh, this_host + from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs @@ -40,27 +64,6 @@ from vllm.v1.structured_output import StructuredOutputManager from vllm.worker.worker_base import WorkerWrapperBase -from forge.actors._torchstore_utils import ( - extract_param_name, - get_dcp_whole_state_dict_key, - get_param_key, - get_param_prefix, - load_tensor_from_dcp, -) - -from forge.controller import ( - ForgeActor, - get_proc_mesh, - host_mesh_from_proc, - stop_proc_mesh, -) -from forge.data_models.completion import Completion -from forge.data_models.prompt import to_prompt -from forge.env import TORCHSTORE_USE_RDMA -from forge.observability.metrics import record_metric, Reduce -from forge.observability.perf_tracker import Tracer -from forge.types import ProcessConfig - logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -92,6 +95,8 @@ class Generator(ForgeActor): engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs) sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams) use_dcp_for_weight_sync: bool | None = None + prefetch_weights_to_shm: bool = True + n_fetcher_procs: int = 8 def __post_init__(self): super().__init__() @@ -112,7 +117,7 @@ def __post_init__(self): self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY if self.use_dcp_for_weight_sync is None: - self.use_dcp_for_weight_sync = not TORCHSTORE_USE_RDMA.get_value() + self.use_dcp_for_weight_sync = not rdma_available() logger.debug(f"{self.use_dcp_for_weight_sync=}") @endpoint @@ -149,7 +154,7 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] worker_procs = await get_proc_mesh(process_config=process_config) # Then, grab a single host from the workers... - host_mesh = await host_mesh_from_proc(worker_procs) + host_mesh = await host_mesh_from_proc.call_one(worker_procs) singleton_slice = {k: slice(0, 1) for k in host_mesh.extent.keys()} host_mesh = host_mesh.slice(**singleton_slice) @@ -226,11 +231,61 @@ async def setup(self): log_stats=None, ) self._start_processing() + if self.prefetch_weights_to_shm: + self._spawn_fetchers() + + def _spawn_fetchers(self): + """Spawn weight fetchers that prefetch weights from torchstore to shared memory.""" + # TODO: this assumes the generator is on the same host as the worker + # and only works for single host generators. Figure out how to support + # generators with workers spanned across multiple hosts. + fetcher_procs = this_host().spawn_procs( + per_host={"procs": self.n_fetcher_procs} + ) + self._fetcher_procs = fetcher_procs + self.weight_fetchers = fetcher_procs.spawn("weight_fetcher", _WeightFetcher) def _start_processing(self): if self._run_task is None or self._run_task.done(): self._run_task = asyncio.create_task(self.run()) + async def _drop_shared_memory(self, state_dict: dict[str, SharedTensorHandle]): + for handle in state_dict.values(): + handle.drop() + + async def _fetch_weights( + self, + version: int, + ) -> dict[str, SharedTensorHandle]: + """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.""" + t = Tracer("generator_perf/_fetch_weights") + t.start() + prefix = get_param_prefix(version) + matching_keys = await ts.keys(prefix) + hf_param_names = [extract_param_name(key) for key in matching_keys] + + n_fetchers = self.weight_fetchers.size() + + def split_keys(keys): + return [keys[i::n_fetchers] for i in range(n_fetchers)] + + futures = [] + for i, names in enumerate(split_keys(hf_param_names)): + fut = self.weight_fetchers.slice(procs=i).fetch.call_one( + version=version, param_names=names + ) + futures.append(fut) + + sub_state_dicts = [await fut for fut in futures] + + state_dict = {} + for sd in sub_state_dicts: + state_dict.update(sd) + + t.stop() + + return state_dict + @endpoint async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: """Generate a response for the given prompt @@ -384,6 +439,12 @@ async def update_weights(self, version: int) -> None: >>> await trainer.push_weights() >>> generator.update_weights(version) """ + # TODO: enable shared memory prefetch for DCP-based weight sync + if self.prefetch_weights_to_shm and not self.use_dcp_for_weight_sync: + logger.info(f"[Generator] Fetching weights for v{version} to shared memory") + fetch_fut = asyncio.create_task(self._fetch_weights(version)) + else: + fetch_fut = None # Serialize updates (only one update at a time) async with self.update_lock: # Grab the lock to stop accepting requests and wait on pending requests @@ -415,8 +476,19 @@ async def update_weights(self, version: int) -> None: ) logger.debug(f"Starting weight update on {self.__class__.__name__}") - # Call update_weights on every generator worker - await self.worker.update_weights.call(version=version) + + if fetch_fut is not None: + t = Tracer("generator_perf/waiting_for_fetch_weights") + t.start() + fetched_weights = await fetch_fut + t.stop() + # Call update_weights on every policy_worker + await self.worker.update_weights.call( + shared_memory_state_dict=fetched_weights + ) + await self._drop_shared_memory(fetched_weights) + else: + await self.worker.update_weights.call(version=version) self.generator_version = version # After updating the weights, we need to reset the KV cache @@ -488,18 +560,20 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride] # TODO - may want to expand stop to gracefully respond to # ongoing requests. await actor.stop.call() - await stop_proc_mesh(actor._worker_procs) - await stop_proc_mesh(actor._generator_proc) + await stop_proc_mesh.call_one(actor._worker_procs) + await stop_proc_mesh.call_one(actor._generator_proc) @endpoint - async def save_model_params(self): - """Used for debugging purpose. Save model parameters before weight update.""" - await self.worker.save_model_params.call() + async def _test_save_model_params(self): + """Save model parameters before weight update, used for tesing purposes only.""" + logger.info("[Generator] save model parameters for testing.") + await self.worker._test_save_model_params.call() @endpoint - async def validate_model_params(self, validate_fn): - """Used for debugging purpose. Validate saved params using validate_fn.""" - return await self.worker.validate_model_params.call(validate_fn) + async def _test_validate_model_params(self, validate_fn): + """Validate updated model params using validate_fn.""" + logger.info("[Generator] start validating model parameters.") + return await self.worker._test_validate_model_params.call(validate_fn) @dataclass @@ -512,6 +586,8 @@ class GeneratorWorker(ForgeActor): """ vllm_config: VllmConfig + # TODO: Remove below param + _test_prev_params = {} @endpoint async def setup(self): @@ -569,14 +645,42 @@ async def execute_model(self, schedule: SchedulerOutput) -> ModelRunnerOutput: return self.worker.execute_model(schedule) @endpoint - async def update_weights(self, version: int) -> None: + async def update_weights( + self, + version: Optional[int] = None, + *, + shared_memory_state_dict: Optional[dict[str, SharedTensorHandle]] = None, + ) -> None: model = self.worker.model_runner.model + if shared_memory_state_dict is not None: + logger.info("[PolicyWorker] update weights from shared memory.") + t = Tracer( + "generator_worker_perf/update_weights_from_shared_memory", timer="gpu" + ) + t.start() + loaded_weights = set() + for name, param_handle in shared_memory_state_dict.items(): + # Use context manager for automatic cleanup + with param_handle.to_shared_tensor() as shared_tensor: + param = shared_tensor.tensor + loaded = model.load_weights([(name, param)]) + del param + loaded_weights.update(loaded) + logger.info(f"[PolicyWorker] updated {len(loaded_weights)} paremeters") + t.stop() + return + # normal update_weights without shared memory prefetching + if version is None: + raise ValueError( + "version must be provided if not using shared_memory_state_dict" + ) + logger.info("[PolicyWorker] update weights from torchstore.") prefix = get_param_prefix(version) matching_keys = await ts.keys(prefix) dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version) use_dcp_for_weight_sync = dcp_whole_state_dict_key in matching_keys loaded_weights = set() - t = Tracer("worker_perf/update_weights", timer="gpu") + t = Tracer("generator_worker_perf/update_weights_from_torchstore", timer="gpu") t.start() if use_dcp_for_weight_sync: @@ -601,19 +705,44 @@ async def update_weights(self, version: int) -> None: t.stop() @endpoint - async def save_model_params(self): - """Used for debugging purposes. Save model parameters before weight update.""" - self._debug_saved_params = {} + async def _test_save_model_params(self): + """Save model parameters before weight update, used for tesing purposes only.""" + logger.info("[GeneratorWorker] save model parameters for testing.") for name, param in self.worker.model_runner.model.named_parameters(): - self._debug_saved_params[name] = param.detach().cpu() + self._test_prev_params[name] = param.detach().cpu() logger.info( "[GeneratorWorker] finished saving model parameters, len = %d", - len(self._debug_saved_params), + len(self._test_prev_params), ) @endpoint - async def validate_model_params(self, validate_fn): - """Used for debugging purposes. Validate saved params using validate_fn.""" + async def _test_validate_model_params(self, validate_fn): + """Validate updated model params using validate_fn.""" + logger.info("[GeneratorWorker] start validating model parameters.") return validate_fn( - self._debug_saved_params, self.worker.model_runner.model, logger + self._test_prev_params, self.worker.model_runner.model, logger ) + + +class _WeightFetcher(ForgeActor): + """Fetches weights from torchstore and loads them into shared memory. + This has to be colocated with the GeneratorWorker.""" + + @endpoint + async def fetch( + self, + *, + version: int, + param_names: list[str], + ) -> dict[str, SharedTensorHandle]: + """Fetch weights from torchstore and load them into shared memory.""" + sd = {} + for name in param_names: + param_key = get_param_key(version, name) + param = await ts.get(param_key) + # Use context manager to ensure cleanup after getting handle + with SharedTensor(tensor=param) as shared_tensor: + handle = shared_tensor.get_handle() + sd[name] = handle + del param # Explicitly free the tensor after copying to shared memory + return sd diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index 71049bc52..c98c836fa 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -41,11 +41,11 @@ DcpHandle, get_dcp_whole_state_dict_key, get_param_key, + rdma_available, ) from forge.controller import ForgeActor from forge.data.utils import batch_to_device -from forge.env import TORCHSTORE_USE_RDMA from forge.observability.metrics import record_metric, Reduce from forge.observability.perf_tracker import Tracer @@ -131,9 +131,7 @@ class RLTrainer(ForgeActor): # Non JobConfig-related fields loss: Callable = lambda logits, **targets: logits state_dict_key: str = "model_state_dict" - use_dcp: bool = ( - TORCHSTORE_USE_RDMA.get_value() == 0 - ) # torchstore currently only accepts 0 or 1 + use_dcp: bool = not rdma_available() dcp_path: str = "forge_dcp_tmp" def __post_init__(self): diff --git a/src/forge/cli/download.py b/src/forge/cli/download.py deleted file mode 100644 index 69ebde9aa..000000000 --- a/src/forge/cli/download.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import argparse - -import json -import os -import textwrap -import traceback - -from pathlib import Path - -from huggingface_hub import snapshot_download -from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError - -from forge.cli.subcommand import Subcommand - -# TODO: update this -REPO_ID_FNAME = "original_repo_id" - - -class Download(Subcommand): - """Holds all the logic for the `forge download` subcommand.""" - - def __init__(self, subparsers: argparse._SubParsersAction): - super().__init__() - self._parser = subparsers.add_parser( - "download", - prog="forge download", - usage="forge download [OPTIONS]", - help="Download a model from the Hugging Face Hub.", - description="Download a model from the Hugging Face Hub.", - epilog=textwrap.dedent( - """\ - examples: - # Download a model from the Hugging Face Hub with a Hugging Face API token - $ forge download meta-llama/Llama-2-7b-hf --hf-token - Successfully downloaded model repo and wrote to the following locations: - /tmp/Llama-2-7b-hf/config.json - /tmp/Llama-2-7b-hf/README.md - /tmp/Llama-2-7b-hf/consolidated.00.pth - ... - - # Download an ungated model from the Hugging Face Hub - $ forge download mistralai/Mistral-7B-Instruct-v0.2 --output-dir /tmp/model - Successfully downloaded model repo and wrote to the following locations: - /tmp/model/config.json - /tmp/model/README.md - /tmp/model/model-00001-of-00002.bin - ... - - For a list of all models, visit the Hugging Face Hub - https://huggingface.co/models. - """ - ), - formatter_class=argparse.RawTextHelpFormatter, - ) - self._add_arguments() - self._parser.set_defaults(func=self._download_cmd) - - def _add_arguments(self) -> None: - """Add arguments to the parser.""" - self._parser.add_argument( - "repo_id", - type=str, - help="Name of the repository on Hugging Face Hub.", - ) - self._parser.add_argument( - "--output-dir", - type=Path, - required=False, - default=None, - help="Directory in which to save the model. Defaults to `/tmp/`.", - ) - self._parser.add_argument( - "--hf-token", - type=str, - required=False, - default=os.getenv("HF_TOKEN", None), - help="Hugging Face API token. Needed for gated models like Llama2.", - ) - self._parser.add_argument( - "--ignore-patterns", - type=str, - required=False, - help="If provided, files matching any of the patterns are not downloaded. Example: '*.safetensors'. " - "Only supported for Hugging Face Hub models.", - ) - - def _download_cmd(self, args: argparse.Namespace) -> None: - return self._download_from_huggingface(args) - - def _download_from_huggingface(self, args: argparse.Namespace) -> None: - """Downloads a model from the Hugging Face Hub.""" - # Download the tokenizer and PyTorch model files - - # Default output_dir is `/tmp/` - output_dir = args.output_dir - if output_dir is None: - model_name = args.repo_id.split("/")[-1] - output_dir = Path("/tmp") / model_name - - print(f"Ignoring files matching the following patterns: {args.ignore_patterns}") - try: - true_output_dir = snapshot_download( - args.repo_id, - local_dir=output_dir, - ignore_patterns=args.ignore_patterns, - token=args.hf_token, - ) - except GatedRepoError: - if args.hf_token: - self._parser.error( - "It looks like you are trying to access a gated repository. Please ensure you " - "have access to the repository." - ) - else: - self._parser.error( - "It looks like you are trying to access a gated repository. Please ensure you " - "have access to the repository and have provided the proper Hugging Face API token " - "using the option `--hf-token` or by running `huggingface-cli login`." - "You can find your token by visiting https://huggingface.co/settings/tokens" - ) - except RepositoryNotFoundError: - self._parser.error( - f"Repository '{args.repo_id}' not found on the Hugging Face Hub." - ) - except Exception as e: - tb = traceback.format_exc() - msg = f"Failed to download {args.repo_id} with error: '{e}' and traceback: {tb}" - self._parser.error(msg) - - # save the repo_id. This is necessary because the download step is a separate command - # from the rest of the CLI. When saving a model adapter, we have to add the repo_id - # to the adapter config. - # TODO: this needs to be updated when we start using HF cache - file_path = os.path.join(true_output_dir, REPO_ID_FNAME + ".json") - with open(file_path, "w") as json_file: - json.dump({"repo_id": args.repo_id}, json_file, indent=4) - - print( - "Successfully downloaded model repo and wrote to the following locations:", - *list(Path(true_output_dir).iterdir()), - sep="\n", - ) diff --git a/src/forge/cli/forge.py b/src/forge/cli/forge.py deleted file mode 100644 index 7e5d2ac73..000000000 --- a/src/forge/cli/forge.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import argparse - -from forge.cli.download import Download -from forge.cli.run import Run - - -class ForgeCLIParser: - """Holds all information related to running the CLI""" - - def __init__(self): - # Initialize the top-level parser - self._parser = argparse.ArgumentParser( - prog="forge", - description="Welcome to the torchforge CLI!", - add_help=True, - ) - # Default command is to print help - self._parser.set_defaults(func=lambda args: self._parser.print_help()) - - # Add subcommands - subparsers = self._parser.add_subparsers(title="subcommands") - Download.create(subparsers) - Run.create(subparsers) - - def parse_args(self) -> argparse.Namespace: - """Parse CLI arguments""" - return self._parser.parse_args() - - def run(self, args: argparse.Namespace) -> None: - """Execute CLI""" - args.func(args) - - -def main(): - parser = ForgeCLIParser() - args = parser.parse_args() - parser.run(args) - - -if __name__ == "__main__": - main() diff --git a/src/forge/cli/run.py b/src/forge/cli/run.py deleted file mode 100644 index 4a556c1f8..000000000 --- a/src/forge/cli/run.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import argparse -import os -import sys -import textwrap - -from pathlib import Path - -from torch.distributed.elastic.multiprocessing.errors import record -from torch.distributed.run import get_args_parser as get_torchrun_args_parser, run - -import forge -from forge.cli.subcommand import Subcommand - -ROOT = Path(forge.__file__).parent.parent - - -class Run(Subcommand): - """Holds all the logic for the `forge run` subcommand.""" - - def __init__(self, subparsers): - super().__init__() - self._parser = subparsers.add_parser( - "run", - prog="forge run", - help="Run a recipe. For distributed recipes, this supports all torchrun arguments.", - description="Run a recipe. For distributed recipes, this supports all torchrun arguments.", - usage="forge run [TORCHRUN-OPTIONS] --config [RECIPE-OPTIONS]", - epilog=textwrap.dedent( - """\ - examples: - - # Run SFT recipe with default values - $ forge run --nproc_per_node 4 apps/sft/sft.py --config apps/sft/configs/llama3_8b.yaml - """ - ), - formatter_class=argparse.RawTextHelpFormatter, - ) - self._add_arguments() - self._parser.set_defaults(func=self._run_cmd) - - def _add_arguments(self) -> None: - """Add arguments to the parser. - - This is a bit hacky since we need to add the torchrun arguments to our parser. - This grabs the argparser from torchrun, iterates over it's actions, and adds them - to our parser. We rename the training_script and training_script_args to recipe and recipe_args - respectively. In addition, we leave out the help argument since we add it manually to ours. - """ - torchrun_argparser = get_torchrun_args_parser() - for action in torchrun_argparser._actions: - if action.dest == "training_script": - action.dest = "recipe" - action.help = """Path to recipe to be launched followed by args.""" - elif action.dest == "training_script_args": - action.dest = "recipe_args" - action.help = "Args to be passed to the recipe." - elif action.dest == "help": - continue - self._parser._add_action(action) - - @record - def _run_distributed(self, args: argparse.Namespace): - """Run a recipe with torchrun.""" - print("Running with torchrun...") - # Have to reset the argv so that the recipe can be run with the correct arguments - args.training_script = args.recipe - args.training_script_args = args.recipe_args - - # If the user does not explicitly pass a rendezvous endpoint, run in standalone mode. - # This allows running multiple distributed training jobs simultaneously. - if not args.rdzv_endpoint: - args.standalone = True - - args.module = True - run(args) - - def _convert_to_dotpath(self, recipe_path: str) -> str: - """Convert a custom recipe path to a dot path that can be run as a module. - - Args: - recipe_path (str): The path of the recipe. - - Returns: - The dot path of the recipe. - """ - filepath, _ = os.path.splitext(recipe_path) - return filepath.replace("/", ".") - - def _run_cmd(self, args: argparse.Namespace): - """Run a recipe.""" - # We have to assume that the recipe supports distributed training - supports_distributed = True - recipe_path, config_path = None, None - - # Try to find config string in args - try: - config_idx = args.recipe_args.index("--config") + 1 - config_str = args.recipe_args[config_idx] - except ValueError: - self._parser.error("The '--config' argument is required.") - - # Get recipe path - recipe_path = self._convert_to_dotpath(args.recipe) - - # Get config path - config_path = config_str - - # Prepare args - args.recipe = recipe_path - args.recipe_args[config_idx] = config_path - - # Make sure user code in current directory is importable - sys.path.append(os.getcwd()) - - self._run_distributed(args) diff --git a/src/forge/cli/subcommand.py b/src/forge/cli/subcommand.py deleted file mode 100644 index db298a0b0..000000000 --- a/src/forge/cli/subcommand.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -class Subcommand: - def __init__(self, *args, **kwargs): - pass - - @classmethod - def create(cls, *args, **kwargs): - return cls(*args, **kwargs) - - def _add_arguments(self): - pass diff --git a/src/forge/controller/__init__.py b/src/forge/controller/__init__.py index a579200e9..763782843 100644 --- a/src/forge/controller/__init__.py +++ b/src/forge/controller/__init__.py @@ -5,9 +5,9 @@ # LICENSE file in the root directory of this source tree. from .actor import ForgeActor from .provisioner import ( + get_or_create_provisioner, get_proc_mesh, host_mesh_from_proc, - init_provisioner, shutdown, stop_proc_mesh, ) @@ -16,7 +16,7 @@ "ForgeActor", "get_proc_mesh", "stop_proc_mesh", - "init_provisioner", "shutdown", "host_mesh_from_proc", + "get_or_create_provisioner", ] diff --git a/src/forge/controller/actor.py b/src/forge/controller/actor.py index 1b8d0a074..627461bbd 100644 --- a/src/forge/controller/actor.py +++ b/src/forge/controller/actor.py @@ -174,7 +174,7 @@ async def as_service( await service.__initialize__() service_interface = ServiceInterface(service, cls) # Register this service with the provisioner so it can cleanly shut this down - await register_service(service_interface) + await register_service.call_one(service_interface) return service_interface @endpoint @@ -234,7 +234,7 @@ async def as_actor(cls: Type[T], *args, **actor_kwargs) -> T: logger.info(f"Spawning actor {cls.__name__}") actor = await cls.launch(*args, **actor_kwargs) # Register this actor with the provisioner so it can cleanly shut this down - await register_actor(actor) + await register_actor.call_one(actor) return actor @classmethod @@ -244,4 +244,4 @@ async def shutdown(cls, actor: "ForgeActor"): """ if actor._proc_mesh is None: raise AssertionError("Called shutdown on a replica with no proc_mesh.") - await stop_proc_mesh(actor._proc_mesh) + await stop_proc_mesh.call_one(actor._proc_mesh) diff --git a/src/forge/controller/launcher.py b/src/forge/controller/launcher.py index 65ece6597..333acbe32 100644 --- a/src/forge/controller/launcher.py +++ b/src/forge/controller/launcher.py @@ -4,10 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Launcher specific logic (i.e. SLURM, k8s when supported, etc.)""" + +import copy import getpass import os import subprocess - import tempfile import uuid from typing import Any @@ -46,6 +48,58 @@ LAUNCHER_KEY = "launcher" +def mount_mnt_directory(mount_dst: str) -> None: + """Mounts the MAST remote directory to the specified destination. + + This function mounts a remote workspace directory that contains huggingface models + and other shared resources needed for training. + + Args: + mount_dst: Destination path where the directory should be mounted (e.g., "/mnt/wsfuse") + """ + # Sanity check of the mounted directory + sanity_path = os.path.join(mount_dst, "huggingface_models/") + if os.path.exists(sanity_path): + return + + # Otherwise, mount the directory + if not os.path.exists(mount_dst): + os.makedirs(mount_dst, exist_ok=True) + + # Store original LD_LIBRARY_PATH to restore after mounting + original_ld_library_path = os.environ.get("LD_LIBRARY_PATH", "") + + try: + clean_env = os.environ.copy() + if "LD_LIBRARY_PATH" in clean_env: + del clean_env["LD_LIBRARY_PATH"] + + subprocess.run( + [ + "/packages/oil.oilfs/oilfs-wrapper", + "ws://ws.ai.pci0ai/genai_fair_llm", + mount_dst, + ], + capture_output=True, + text=True, + check=True, + env=clean_env, + ) + print("Done mounting") + except subprocess.CalledProcessError as e: + print(f"Get error during mounting {e}, Stderr: {e.stderr}, Stdout: {e.stdout}") + finally: + # Restore original LD_LIBRARY_PATH + if original_ld_library_path: + os.environ["LD_LIBRARY_PATH"] = original_ld_library_path + elif "LD_LIBRARY_PATH" in os.environ: + del os.environ["LD_LIBRARY_PATH"] + + assert os.path.exists( + sanity_path + ), f"Did not find directory {sanity_path}; something wrong with mounting." + + class MastSetupActor(Actor): @endpoint def mount(self, mount_dst: str): @@ -56,53 +110,7 @@ def mount(self, mount_dst: str): if current_rank().rank % proc_count != 0: # Only use one rank per host to mount the directory return - self.mount_mnt_directory(mount_dst) - - def mount_mnt_directory(self, mount_dst: str) -> None: - # Sanity check of the mounted directory - sanity_path = os.path.join(mount_dst, "huggingface_models/") - if os.path.exists(sanity_path): - print(f"Found directory {sanity_path}; skip mounting.") - return - - # Otherwise, mount the directory - if not os.path.exists(mount_dst): - os.makedirs(mount_dst, exist_ok=True) - - # Store original LD_LIBRARY_PATH to restore after mounting - original_ld_library_path = os.environ.get("LD_LIBRARY_PATH", "") - - try: - clean_env = os.environ.copy() - if "LD_LIBRARY_PATH" in clean_env: - del clean_env["LD_LIBRARY_PATH"] - - subprocess.run( - [ - "/packages/oil.oilfs/oilfs-wrapper", - "ws://ws.ai.pci0ai/genai_fair_llm", - mount_dst, - ], - capture_output=True, - text=True, - check=True, - env=clean_env, - ) - print("Done mounting") - except subprocess.CalledProcessError as e: - print( - f"Get error during mounting {e}, Stderr: {e.stderr}, Stdout: {e.stdout}" - ) - finally: - # Restore original LD_LIBRARY_PATH - if original_ld_library_path: - os.environ["LD_LIBRARY_PATH"] = original_ld_library_path - elif "LD_LIBRARY_PATH" in os.environ: - del os.environ["LD_LIBRARY_PATH"] - - assert os.path.exists( - sanity_path - ), f"Did not find directory {sanity_path}; something wrong with mounting." + mount_mnt_directory(mount_dst) class BaseLauncher: @@ -157,18 +165,49 @@ async def remote_setup(self, procs: ProcMesh) -> None: return -class Mastlauncher(BaseLauncher): - def __init__(self, cfg: LauncherConfig | None = None): +class MastLauncher(BaseLauncher): + """Launcher for MAST (Meta's internal cluster scheduler). + + This launcher supports two modes of operation: + + 1. Non-detached mode (detached=False): + - Client runs on your local machine/devserver + - Only worker roles (GPU hosts) are launched in MAST + - Client connects to workers remotely via provisioner + + 2. Detached mode (detached=True): + - Client runs entirely inside MAST as a separate role + - Both client role (CPU-only) and worker roles (GPU) are launched in MAST + - Client role executes the training script with --mode=remote + - Everything runs in the cluster, no client needed on local machine + + Args: + cfg: Launcher configuration including job name, services, and actors + detached: If True, adds a client role to the MAST job appdef that runs + the training script inside MAST. If False, only launches worker + roles and expects the client to run on local machine. + extra_args: Additional CLI arguments to pass through to the client role. + + """ + + def __init__( + self, + cfg: LauncherConfig | None = None, + detached: bool = False, + extra_args: list = None, + ): assert cfg is not None self.cfg = cfg + self.detached = detached self.default_monarch_port = 26600 + self.extra_args = extra_args or [] self.scheduler_name = "mast_conda" - # TODO: enabe taking this from config + # TODO: enable taking this from config self.sku = "gtt_any" self.timeout_sec = 1 * 60 * 60 # Kill the job if idle for 1 hour self.user = getpass.getuser() - self.work_dir = f"/data/users/{self.user}" + self.work_dir = f"/home/{self.user}" self.edittable_workspaces = ["forge"] self.remote_work_dir = "/packages/monarch_default_workspace/workspace/" self.editable_workspace_paths = [ @@ -182,8 +221,6 @@ async def initialize(self) -> None: # This can be removed in the future once this has been removed. configure(default_transport=ChannelTransport.MetaTlsWithHostname) - await self.launch_mast_job() - async def get_allocator(self, name: str, num_hosts: int) -> tuple[Any, Any, str]: allocator = MastAllocator( MastAllocatorConfig( @@ -255,11 +292,15 @@ def build_appdef(self) -> specs.AppDef: "TORCHDYNAMO_VERBOSE": "1", "VLLM_TORCH_COMPILE_LEVEL": "0", "VLLM_USE_TRITON_FLASH_ATTN": "0", + "WANDB_MODE": "offline", + "HF_HUB_OFFLINE": "1", + "MONARCH_HOST_MESH_V1_REMOVE_ME_BEFORE_RELEASE": "1", + "TORCHSTORE_RDMA_ENABLED": "1", + "HF_HOME": "/mnt/wsfuse/teamforge/hf", + "TRANSFORMERS_OFFLINE": "1", }, } - print("DEFAULT ENVS: ", default_envs) - packages = Packages() meshes = [] # Process both services and actors configurations @@ -289,6 +330,15 @@ def build_appdef(self) -> specs.AppDef: timeout_sec=self.timeout_sec, env=default_envs, ) + appdef.metadata["mast"] = { + "HpcJobDefinition": { + "networkAffinity": { + # Ensure colocation + "preferredScope": 3, # DC + "fallbackScope": 3, # REGION + }, + }, + } for role in appdef.roles: role.resource.capabilities["server_sub_types"] = [ @@ -296,8 +346,45 @@ def build_appdef(self) -> specs.AppDef: role.resource.capabilities["server_sub_types"][1] # GTT ] + # Add client role to run in MAST if in detached mode + if self.detached: + client_role = self._create_client_role(appdef) + appdef.roles.insert(0, client_role) + return appdef + def _create_client_role(self, appdef: specs.AppDef) -> specs.Role: + # Clone an existing worker role to inherit workspace configuration + if not appdef.roles: + raise ValueError( + "Cannot create client role: no worker roles exist to clone from" + ) + + # Clone the first worker role + client_role = copy.deepcopy(appdef.roles[0]) + + # Override with client-specific configuration + client_role.name = "client" + # Use the bootstrap script as entrypoint + client_role.entrypoint = "workspace/forge/.meta/mast/client_bootstrap.sh" + + # Build args for the client role (passed to the bootstrap script) + # These args will be passed to client_bootstrap.sh which forwards them to main.py + args = [ + "--mode=remote", + "--job-name", + self.job_name, + ] + + # Add any extra args passed from the CLI (includes --config and other args) + if self.extra_args: + args.extend(self.extra_args) + + client_role.args = args + client_role.num_replicas = 1 + + return client_role + def create_job_name(self): return f"{self.user}-forge-{uuid.uuid4().hex[:6]}" @@ -315,6 +402,6 @@ def get_launcher(cfg: LauncherConfig | None = None) -> BaseLauncher | None: raise ValueError( "MAST imports did not succeed, cannot launch MAST jobs. Please verify your installation" ) - return Mastlauncher(cfg) + return MastLauncher(cfg, detached=False) else: raise ValueError(f"Unsupported config provided, got {cfg}") diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 5549a8cce..b030e0668 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -13,19 +13,26 @@ import socket import uuid +from forge.controller.launcher import BaseLauncher, get_launcher +from forge.env import all_env_vars, FORGE_DISABLE_METRICS +from forge.types import ProcessConfig, ProvisionerConfig + from monarch._src.actor.actor_mesh import ActorMesh from monarch._src.actor.shape import Extent -from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host +from monarch.actor import ( + Actor, + endpoint, + get_or_spawn_controller, + HostMesh, + ProcMesh, + this_host, +) from monarch.tools import commands from monarch.utils import setup_env_for_distributed -from forge.controller.launcher import BaseLauncher, get_launcher -from forge.env import all_env_vars, FORGE_DISABLE_METRICS -from forge.types import ProcessConfig, ProvisionerConfig - logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -95,7 +102,7 @@ def release_gpus(self, gpu_ids: list[str]) -> None: self.available_gpus.add(int(gpu_id)) -class Provisioner: +class Provisioner(Actor): """A global resource provisioner.""" def __init__(self, cfg: ProvisionerConfig | None = None): @@ -138,11 +145,13 @@ def __init__(self, cfg: ProvisionerConfig | None = None): self._registered_actors: list["ForgeActor"] = [] self._registered_services: list["ServiceInterface"] = [] + @endpoint async def initialize(self): """Call this after creating the instance""" if self.launcher is not None: await self.launcher.initialize() + @endpoint async def create_host_mesh(self, name: str, num_hosts: int) -> HostMesh: """Creates a remote server and a HostMesh on it.""" # no need to lock here because this is already locked behind `get_proc_mesh` @@ -172,6 +181,7 @@ async def create_host_mesh(self, name: str, num_hosts: int) -> HostMesh: ) return host_mesh, server_name + @endpoint def get_host_mesh(self, name: str) -> HostMesh: """Returns the host mesh given its associated name. @@ -181,6 +191,7 @@ def get_host_mesh(self, name: str) -> HostMesh: """ return self._host_mesh_map[name] + @endpoint async def get_proc_mesh( self, num_procs: int, @@ -225,7 +236,7 @@ async def get_proc_mesh( created_hosts = len(self._server_names) mesh_name = f"alloc_{created_hosts}" if host_mesh is None: - host_mesh, server_name = await self.create_host_mesh( + host_mesh, server_name = await self.create_host_mesh.call_one( name=mesh_name, num_hosts=num_hosts, ) @@ -310,13 +321,15 @@ def bootstrap(env: dict[str, str]): self._proc_host_map[procs] = host_mesh - # Spawn local fetcher actor on each process and register with global logger + # Spawn LocalFetcherActor for this ProcMesh and register with GlobalLoggingActor. + # When called, the LocalFetcherActor is broadcast by Monarch to all ranks in the ProcMesh. if not FORGE_DISABLE_METRICS.get_value(): from forge.observability.metric_actors import get_or_create_metric_logger - _ = await get_or_create_metric_logger(procs) + _ = await get_or_create_metric_logger(procs, process_name=mesh_name) return procs + @endpoint async def host_mesh_from_proc(self, proc_mesh: ProcMesh): if proc_mesh not in self._proc_host_map: raise ValueError( @@ -324,6 +337,7 @@ async def host_mesh_from_proc(self, proc_mesh: ProcMesh): ) return self._proc_host_map[proc_mesh] + @endpoint async def stop_proc_mesh(self, proc_mesh: ProcMesh): """Stops a proc mesh.""" if proc_mesh not in self._proc_host_map: @@ -333,14 +347,14 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh): ) return async with self._lock: - # Deregister local logger from global logger - if hasattr(proc_mesh, "_local_fetcher"): + # Deregister LocalFetcherActor from GlobalLoggingActor + if hasattr(proc_mesh, "_local_fetcher") and hasattr(proc_mesh, "_uid"): from forge.observability.metric_actors import ( get_or_create_metric_logger, ) global_logger = await get_or_create_metric_logger(proc_mesh) - await global_logger.deregister_fetcher.call_one(proc_mesh) + await global_logger.deregister_fetcher.call_one(proc_mesh._uid) if hasattr(proc_mesh, "_gpu_ids"): gpu_manager = self._host_gpu_map[proc_mesh._host._host_id] @@ -351,6 +365,7 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh): commands.kill(server_name) del self._proc_host_map[proc_mesh] + @endpoint def register_service(self, service: "ServiceInterface") -> None: """Registers a service allocation for cleanup.""" # Import ServiceInterface here instead of at top-level to avoid circular import @@ -363,6 +378,7 @@ def register_service(self, service: "ServiceInterface") -> None: self._registered_services.append(service) + @endpoint def register_actor(self, actor: "ForgeActor") -> None: """Registers a single actor allocation for cleanup.""" @@ -371,6 +387,7 @@ def register_actor(self, actor: "ForgeActor") -> None: self._registered_actors.append(actor) + @endpoint async def shutdown_all_allocations(self): """Gracefully shut down all tracked actors and services.""" logger.info( @@ -397,29 +414,46 @@ async def shutdown_all_allocations(self): self._registered_actors.clear() self._registered_services.clear() + @endpoint async def shutdown(self): """Tears down all remaining remote allocations.""" - await self.shutdown_all_allocations() + await self.shutdown_all_allocations.call_one() async with self._lock: for server_name in self._server_names: commands.kill(server_name) -_provisioner: Provisioner | None = None +_global_provisioner: Provisioner | None = None + + +async def get_or_create_provisioner( + cfg: ProvisionerConfig | None = None, +) -> Provisioner: + """Gets or spawns the global Provisioner controller actor.""" + global _global_provisioner + if _global_provisioner is None: + _global_provisioner = await get_or_spawn_controller( + "provisioner_controller", Provisioner, cfg + ) + await _global_provisioner.initialize.call_one() + return _global_provisioner + + +# _provisioner: Provisioner | None = None -async def init_provisioner(cfg: ProvisionerConfig | None = None): - global _provisioner - if not _provisioner: - _provisioner = Provisioner(cfg) - await _provisioner.initialize() - return _provisioner +# async def init_provisioner(cfg: ProvisionerConfig | None = None): +# global _provisioner +# if not _provisioner: +# _provisioner = Provisioner(cfg) +# await _provisioner.initialize() +# return _provisioner -async def _get_provisioner(): - if not _provisioner: - await init_provisioner() - return _provisioner +# async def _get_provisioner(): +# if not _provisioner: +# await init_provisioner() +# return _provisioner async def get_proc_mesh( @@ -444,8 +478,8 @@ async def get_proc_mesh( A proc mesh. """ - provisioner = await _get_provisioner() - return await provisioner.get_proc_mesh( + provisioner = await get_or_create_provisioner() + return await provisioner.get_proc_mesh.call_one( num_procs=process_config.procs, with_gpus=process_config.with_gpus, num_hosts=process_config.hosts, @@ -464,25 +498,25 @@ async def host_mesh_from_proc(proc_mesh: ProcMesh): API. """ - provisioner = await _get_provisioner() - return await provisioner.host_mesh_from_proc(proc_mesh) + provisioner = await get_or_create_provisioner() + return await provisioner.host_mesh_from_proc.call_one(proc_mesh) async def register_service(service: "ServiceInterface") -> None: """Registers a service allocation with the global provisioner.""" - provisioner = await _get_provisioner() - provisioner.register_service(service) + provisioner = await get_or_create_provisioner() + provisioner.register_service.call_one(service) async def register_actor(actor: "ForgeActor") -> None: """Registers an actor allocation with the global provisioner.""" - provisioner = await _get_provisioner() - provisioner.register_actor(actor) + provisioner = await get_or_create_provisioner() + provisioner.register_actor.call_one(actor) async def stop_proc_mesh(proc_mesh: ProcMesh): - provisioner = await _get_provisioner() - return await provisioner.stop_proc_mesh(proc_mesh=proc_mesh) + provisioner = await get_or_create_provisioner() + return await provisioner.stop_proc_mesh.call_one(proc_mesh=proc_mesh) async def shutdown_metric_logger(): @@ -503,8 +537,8 @@ async def shutdown(): logger.info("Shutting down provisioner..") - provisioner = await _get_provisioner() - result = await provisioner.shutdown() + provisioner = await get_or_create_provisioner() + result = await provisioner.shutdown.call_one() logger.info("Shutdown completed successfully") return result diff --git a/src/forge/data/datasets/dataset.py b/src/forge/data/datasets/dataset.py index 57a624c67..f18d9e07e 100644 --- a/src/forge/data/datasets/dataset.py +++ b/src/forge/data/datasets/dataset.py @@ -61,7 +61,7 @@ class DatasetInfo: class TuneIterableDataset(IterableDataset, ABC): - """Base class for all torchtune iterable datasets. + """Base class for all forge iterable datasets. Datasets are composable, enabling complex structures such as: ``PackedDataset(InterleavedDataset([InterleavedDataset([ds1, ds2]), ds3]))`` diff --git a/src/forge/data/datasets/sft_dataset.py b/src/forge/data/datasets/sft_dataset.py index fca97f912..b31d16fad 100644 --- a/src/forge/data/datasets/sft_dataset.py +++ b/src/forge/data/datasets/sft_dataset.py @@ -22,8 +22,7 @@ class AlpacaToMessages(Transform): (or equivalent fields specified in column_map) columns. User messages are formed from the instruction + input columns and assistant messages are formed from the output column. Prompt templating is conditional on the presence of the "input" column, and thus is handled directly - in this transform class instead of a dedicated :class:`~torchtune.data.PromptTemplate` class - due to this custom logic. + in this transform class. Args: column_map (dict[str, str] | None): a mapping to change the expected "instruction", "input", diff --git a/src/forge/data/tokenizer.py b/src/forge/data/tokenizer.py index d93c5d4a4..65407e131 100644 --- a/src/forge/data/tokenizer.py +++ b/src/forge/data/tokenizer.py @@ -20,7 +20,7 @@ class HuggingFaceBaseTokenizer(BaseTokenizer): """ A wrapper around Hugging Face tokenizers. See https://github.com/huggingface/tokenizers - This can be used to load from a Hugging Face tokenizer.json file into a torchtune BaseTokenizer. + This can be used to load from a Hugging Face tokenizer.json file into a forge BaseTokenizer. This class will load the tokenizer.json file from tokenizer_json_path. It will attempt to infer BOS and EOS token IDs from config.json if possible, and if not @@ -210,7 +210,7 @@ class HuggingFaceModelTokenizer(ModelTokenizer): Then, it will load all special tokens and chat template from tokenizer config file. It can be used to tokenize messages with correct chat template, and it eliminates the requirement of - the specific ModelTokenizer and custom PromptTemplate. + the specific ModelTokenizer. Args: tokenizer_json_path (str): Path to tokenizer.json file diff --git a/src/forge/data/utils.py b/src/forge/data/utils.py index b2fdaec0c..be8c13857 100644 --- a/src/forge/data/utils.py +++ b/src/forge/data/utils.py @@ -32,7 +32,7 @@ class TuneMessage: """ This class represents individual messages in a fine-tuning dataset. It supports text-only content, text with interleaved images, and tool calls. The - :class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` will tokenize + :class:`~forge.interfaces.ModelTokenizer` will tokenize the content of the message using ``tokenize_messages`` and attach the appropriate special tokens based on the flags set in this class. @@ -61,8 +61,7 @@ class TuneMessage: - All ipython messages (tool call returns) should set ``eot=False``. Note: - TuneMessage class expects any image content to be a ``torch.Tensor``, as output - by e.g. :func:`~torchtune.data.load_image` + TuneMessage class expects any image content to be a ``torch.Tensor``. """ def __init__( diff --git a/src/forge/env.py b/src/forge/env.py index 1478909da..b698b8013 100644 --- a/src/forge/env.py +++ b/src/forge/env.py @@ -101,7 +101,7 @@ def get_value(self) -> Any: TORCHSTORE_USE_RDMA = EnvVar( name="TORCHSTORE_RDMA_ENABLED", - default=0, + default=1, description="Whether or not to use RDMA in TorchStore.", ) diff --git a/src/forge/envs/__init__.py b/src/forge/envs/__init__.py deleted file mode 100644 index 2e41cd717..000000000 --- a/src/forge/envs/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. diff --git a/src/forge/envs/chat.py b/src/forge/envs/chat.py deleted file mode 100644 index 24a5981a6..000000000 --- a/src/forge/envs/chat.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from dataclasses import dataclass, field - -import torch - -from forge.interfaces import Environment, Message, ModelTokenizer, Transform - -from forge.types import Action, Observation, State - - -@dataclass -class ChatAction(Action): - """Action for chat environments. - - Contains tokens that represent the action to be taken. - This interfaces directly with models. - """ - - tokens: torch.Tensor = field(default_factory=lambda: torch.tensor([])) - - def __post_init__(self): - """Validate required fields after initialization.""" - if self.tokens.numel() == 0: - raise ValueError("tokens is required and cannot be empty") - - -@dataclass -class ChatState(State): - """State of the ChatEnvironment containing message history.""" - - history_messages: list[Message] = field(default_factory=list) - history_tokens: list[torch.Tensor] = field( - default_factory=list - ) # Same len as messages - - -@dataclass -class ChatObservation(Observation): - """Observation returned by ChatEnvironment. - - Contains the message history in Huggingface format (list of dicts with role/content) - and the tokenized representation of the entire conversation. - - The environment owns the tokenizer and generates the tokens from the messages. - - Example: - messages = [ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "How tall is the Eiffel Tower?"}, - ] - tokens = tensor([1, 2, 3, 4, 5, ...]) # tokenized entire conversation - """ - - messages: list[Message] = field(default_factory=list) - tokens: torch.Tensor = field(default_factory=lambda: torch.tensor([])) - # Inherited fields from Observation ABC: reward, done, metadata - - -class ChatEnvironment(Environment): - """A chat-based environment for LLMs, designed as a blank canvas for conversation and RL. - - This environment is designed to work with language models. It provides the fundamental structure - for managing conversation state but is intentionally minimal to allow maximum flexibility. - - The environment owns the tokenizer and is responsible for managing both message history and tokens. - Actions contain only tokens that interface directly with models. - - Args: - tokenizer: A tokenizer that will be used to tokenize the conversation - system_prompt: An optional system prompt string to use during reset calls (optional) - system_role: The role of the system (at reset time). Defaults to "system" - """ - - def __init__( - self, - tokenizer: ModelTokenizer, - system_prompt: str | None = None, - system_role: str = "system", - transform: Transform | None = None, - ): - super().__init__(transform=transform) - - if not hasattr(tokenizer, "apply_chat_template"): - raise ValueError("Tokenizer must have 'apply_chat_template' method") - self.tokenizer = tokenizer - self.system_prompt = system_prompt - self.system_role = system_role - - self._state = ChatState() - - if system_prompt: - system_message: Message = {"role": system_role, "content": system_prompt} - self._state.history_messages.append(system_message) - # Tokenize the system message - system_tokens = self.tokenizer.apply_chat_template( - conversation=[system_message], tokenize=True, return_tensors="pt" # type: ignore - ) - self._state.history_tokens.append(system_tokens) - - def reset(self) -> ChatObservation: - """Reset the environment to initial state. - - Returns: - ChatObservation: Initial observation with system prompt (if any) - """ - self._state.history_messages = [] - self._state.history_tokens = [] - if self.system_prompt: - system_message: Message = { - "role": self.system_role, - "content": self.system_prompt, - } - self._state.history_messages = [system_message] - # Tokenize the system message - system_tokens = self.tokenizer.apply_chat_template( - conversation=[system_message], tokenize=True, return_tensors="pt" # type: ignore - ) - self._state.history_tokens = [system_tokens] - - return self._create_observation() - - def step(self, action: ChatAction) -> ChatObservation: - """Take a step in the environment by adding tokens to the chat history. - - Args: - action: A ChatAction object containing tokens. - - Returns: - ChatObservation: The updated observation with the new tokens added. - """ - # Store the tokens directly from the action - self._state.history_tokens.append(action.tokens) - - # Decode tokens to text and add as a message to history - decoded_text = self.tokenizer.decode( - action.tokens.squeeze(), skip_special_tokens=True - ) - assistant_message: Message = {"role": "assistant", "content": decoded_text} - self._state.history_messages.append(assistant_message) - - return self._create_observation() - - def _create_observation(self) -> ChatObservation: - """Create a ChatObservation from the current state. - - Returns both the message history and the tokens flattened as a single tensor - ready to be used by models. - - Returns: - ChatObservation: Observation with messages and flattened tokens - """ - if self._state.history_tokens: - flattened_tokens = torch.cat(self._state.history_tokens, dim=0) - else: - flattened_tokens = torch.tensor([]) - - observation = ChatObservation( - messages=self._state.history_messages.copy(), # Copy to prevent external mutation - tokens=flattened_tokens, - ) - - transformed = self._apply_transform(observation) - if isinstance(transformed, ChatObservation): - return transformed - else: - # If transform returns base Observation, convert back to ChatObservation - return ChatObservation( - messages=getattr(transformed, "messages", []), - tokens=getattr(transformed, "tokens", torch.tensor([])), - done=transformed.done, - reward=transformed.reward, - ) - - @property - def state(self) -> ChatState: - """Get the current state of the environment. - - Returns: - ChatState: The current state. - """ - return self._state - - def message_to_action(self, message: Message) -> ChatAction: - """Convert a message dictionary to a ChatAction with tokens. - - Args: - message: Dictionary with 'role' and 'content' keys - - Returns: - ChatAction: A new ChatAction instance with tokenized content - - Raises: - ValueError: If required keys are missing - """ - if "role" not in message: - raise ValueError("Message must contain a 'role' key") - if "content" not in message: - raise ValueError("Message must contain a 'content' key") - if message["content"] is None: - raise ValueError("Message content cannot be None") - - # Tokenize the single message - tokens = self.tokenizer.apply_chat_template( - conversation=[message], tokenize=True, return_tensors="pt" # type: ignore - ) - - return ChatAction(tokens=tokens) diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py index df79c302e..526e36c56 100644 --- a/src/forge/interfaces.py +++ b/src/forge/interfaces.py @@ -7,11 +7,7 @@ from abc import ABC, abstractmethod from typing import Any, Mapping -from monarch.actor import endpoint - -from forge.controller import ForgeActor - -from forge.types import Action, Message, Observation, Scalar, State +from forge.types import Message, Observation, Scalar class Transform(ABC): @@ -37,68 +33,10 @@ def __call__(self, observation: Observation) -> Observation: pass -class Environment(ABC): - """Abstract base class for environments. - - Args: - transform: Optional transform that modifies observations, typically to add rewards. - Can be a Transform instance or a callable for backward compatibility. - """ - - def __init__( - self, - transform: Transform | None = None, - ): - self.transform = transform - - @abstractmethod - def reset(self) -> Observation: - """Reset the environment and return an initial observation.""" - pass - - @abstractmethod - def step(self, action: Any) -> Observation: - """Take a step in the environment and return an observation.""" - pass - - @property - @abstractmethod - def state(self) -> State: - """Get the current state of the environment.""" - pass - - def _apply_transform(self, observation: Observation) -> Observation: - """Apply the transform to an observation if one is provided.""" - if self.transform is not None: - return self.transform(observation) - return observation - - -class Policy(ForgeActor, ABC): - """Abstract interface for policies.""" - - @endpoint - @abstractmethod - async def generate(self, request: Observation) -> Action: - """Generate an action given a state/request.""" - pass - - @endpoint - @abstractmethod - async def update_weights(self, policy_version: int): - """Update the policy weights. - - Args: - policy_version: The version number to update to. - """ - pass - - class BaseTokenizer(ABC): """ Abstract token encoding model that implements ``encode`` and ``decode`` methods. - See :class:`~torchtune.modules.transforms.tokenizers.SentencePieceBaseTokenizer` and - :class:`~torchtune.modules.transforms.tokenizers.TikTokenBaseTokenizer` for example implementations of this protocol. + See :class:`forge.data.HuggingFaceModelTokenizer for an example implementation of this protocol. """ @abstractmethod @@ -133,7 +71,7 @@ def decode(self, token_ids: list[int], **kwargs: dict[str, Any]) -> str: class ModelTokenizer(ABC): """ Abstract tokenizer that implements model-specific special token logic in - the ``tokenize_messages`` method. See :class:`~torchtune.models.llama3.Llama3Tokenizer` + the ``tokenize_messages`` method. See :class:`forge.data.HuggingFaceModelTokenizer` for an example implementation of this protocol. """ @@ -210,10 +148,3 @@ class Reward(ABC): def __call__(self, observation: Observation) -> float: """Compute a reward for an observation.""" pass - - -# TODO -# class RLLoss(ABC): - -# class SFTLoss(ABC): # inherit from titan loss -# from torchtitan.components.loss import LossFunction diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index b970e57fa..8efd3dace 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -12,8 +12,6 @@ from .metrics import ( BackendRole, ConsoleBackend, - get_actor_name_with_rank, - get_logger_backend_class, LoggerBackend, MaxAccumulator, MeanAccumulator, @@ -29,12 +27,12 @@ WandbBackend, ) from .perf_tracker import trace, Tracer +from .utils import get_proc_name_with_rank __all__ = [ # Main API functions "record_metric", "reduce_metrics_states", - "get_actor_name_with_rank", "get_logger_backend_class", "get_or_create_metric_logger", # Performance tracking @@ -45,6 +43,8 @@ "BackendRole", # Enums "Reduce", + # Utility functions + "get_proc_name_with_rank", # Actor classes "GlobalLoggingActor", "LocalFetcherActor", diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index 83ddd349e..ee6fe6277 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -6,9 +6,17 @@ import asyncio import logging +import uuid from typing import Any, Union -from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc +from monarch.actor import ( + Actor, + context, + endpoint, + get_or_spawn_controller, + ProcMesh, + this_proc, +) from forge.env import FORGE_DISABLE_METRICS from forge.observability.metrics import ( @@ -27,36 +35,35 @@ async def get_or_create_metric_logger( proc_mesh: ProcMesh | None = None, + process_name: str | None = None, ) -> "GlobalLoggingActor": - """Initializes a LocalFetcherActor in the specified process mesh (or current process if None), - if not already initialized, registers it with the GlobalLoggingActor and returns the - GlobalLoggingActor instance. + """Spawns a LocalFetcherActor for the specified ProcMesh (if not already initialized), + registers it with the GlobalLoggingActor, and returns the GlobalLoggingActor. - There are primarily two ways to use this function: - 1. In the main process, call `get_or_create_metric_logger()` to get the global logger. - 2. In service processes, call `get_or_create_metric_logger(proc_mesh)` to register the - local fetcher with the global logger. + Usage: + 1. Main process: call `get_or_create_metric_logger()` to get the global logger + 2. Service spawning: call `get_or_create_metric_logger(proc_mesh, process_name)` to register the + map(proc_mesh,local fetcher) with the global logger, so it knows to broadcast to all ranks. Args: - proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None, - uses `monarch.actor.this_proc()`. + proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None, uses `this_proc()`. + process_name: Optional process name (e.g., "TrainActor") for logging. Auto-detected from the context if None. Returns: GlobalLoggingActor: The global logging controller. Raises: - ValueError: If the logging state is inconsistent, i.e. the fetcher is already - registered, but only in the process or the global logger. + ValueError: If the logging state is inconsistent. Example: from forge.observability.metric_actors import get_or_create_metric_logger from forge.observability.metrics import record_metric # Main process setup - mlogger = await get_or_create_metric_logger() + mlogger = await get_or_create_metric_logger(process_name="Controller") # Initialize logging backends - await mlogger.init_backends({ + await mlogger.init_backends.call_one({ "console": {"reduce_across_ranks": True}, "wandb": {"project": "my_project", "reduce_across_ranks": False} }) @@ -66,15 +73,16 @@ async def get_or_create_metric_logger( # Training loop for step in range(max_steps): - record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN) + record_metric("loss", 1.2, reduction_type=Reduce.MEAN) # ... training code with record_metric() calls ... - await mlogger.flush(step) # Log metrics for this step + await mlogger.flush.call_one(step) # Log metrics for this step # Shutdown - await mlogger.shutdown() + await mlogger.shutdown.call_one() """ # Get or create the singleton global logger global _global_logger + if _global_logger is None: _global_logger = await get_or_spawn_controller( "global_logger", GlobalLoggingActor @@ -84,9 +92,15 @@ async def get_or_create_metric_logger( # Determine process context proc = proc_mesh if proc_mesh is not None else this_proc() + # Auto-detect process_name from proc mesh if not provided + if process_name is None: + ctx = context() + process_name = ctx.actor_instance.actor_id.actor_name + # Check current state for consistency proc_has_local_fetcher = hasattr(proc, "_local_fetcher") - global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc) + proc_id = proc._uid if proc_has_local_fetcher else None + global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc_id) # Consistency check: both should be in sync if proc_has_local_fetcher != global_logger_has_local_fetcher: @@ -101,24 +115,32 @@ async def get_or_create_metric_logger( # Setup local_fetcher_actor if needed (unless disabled by environment flag) if not proc_has_local_fetcher and not FORGE_DISABLE_METRICS.get_value(): local_fetcher_actor = proc.spawn( - "local_fetcher_actor", LocalFetcherActor, global_logger + "local_fetcher_actor", LocalFetcherActor, global_logger, process_name ) - await global_logger.register_fetcher.call_one(local_fetcher_actor, proc) + # Generate a unique ID to map procmesh to fetcher + proc._uid = str(uuid.uuid4()) proc._local_fetcher = local_fetcher_actor # pyre-ignore + await global_logger.register_fetcher.call_one(local_fetcher_actor, proc._uid) + return global_logger class LocalFetcherActor(Actor): - """Thin per-process actor used to trigger MetricCollector singleton - operations without direct access. It is what GlobalLoggingActor - uses to broadcast inits/flushes across ranks. + """Actor spawned once per ProcMesh that, when called, runs on every rank in that ProcMesh + and accesses each rank's local MetricCollector. - GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector + Flow: + GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger """ - def __init__(self, global_logger: Union["GlobalLoggingActor", None] = None) -> None: + def __init__( + self, + global_logger: Union["GlobalLoggingActor", None] = None, + process_name: str | None = None, + ) -> None: self.global_logger = global_logger + self.process_name = process_name _is_initialized = False @endpoint @@ -145,10 +167,22 @@ async def init_backends( self, metadata_per_primary_backend: dict[str, dict[str, Any]], config: dict[str, Any], + global_step: int = 0, ) -> None: - """Init local (per-rank) logger backends and MetricCollector.""" + """Init per-rank logger backends and MetricCollector. + + Args: + metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state. + config (dict[str, Any]): Backend configurations with logging modes and settings. + global_step (int): Initial step for metrics. + """ collector = MetricCollector() - await collector.init_backends(metadata_per_primary_backend, config) + await collector.init_backends( + metadata_per_primary_backend, + config, + global_step, + process_name=self.process_name, + ) @endpoint async def shutdown(self) -> None: @@ -157,22 +191,17 @@ async def shutdown(self) -> None: class GlobalLoggingActor(Actor): - """Coordinates metric logging across all ranks for every training step. + """Coordinates metric logging across all ProcMeshes and their ranks. Supports multiple logging backends (e.g., WandB, TensorBoard, etc.), - for per-rank and/or global reduction logging modes. + with per-rank and/or global reduction logging modes. If a backend config has flag `reduce_across_ranks=False`, an instance of the backend is initialized per-rank, otherwise it is done once globally. - This GlobalLoggingActor should be spawned once in the controller. A LocalFetcherActor - is automatically spawned per-rank in `forge.controller.provisioner.py` and registered - with this actor. The LocalFetcherActor is responsible for instantiating - the per-rank MetricCollector. - In summary, the flow is: - - GlobalLoggingActor init_backends() -> LocalFetcherActor init_backends() -> per-rank MetricCollector - - GlobalLoggingActor flush() -> LocalFetcherActor flush() -> per-rank MetricCollector flush + Flow: + GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger """ def __init__(self): @@ -203,9 +232,12 @@ async def init_backends(self, config: dict[str, Any]) -> None: """ self.config = config + if FORGE_DISABLE_METRICS.get_value(): + return + for backend_name, backend_config in config.items(): backend = get_logger_backend_class(backend_name)(backend_config) - await backend.init(role=BackendRole.GLOBAL) + await backend.init(role=BackendRole.GLOBAL, name="global_reduce") # Extract metadata from primary logger to be shared with secondary loggers # and store it @@ -233,30 +265,31 @@ async def init_backends(self, config: dict[str, Any]) -> None: await asyncio.gather(*tasks, return_exceptions=True) @endpoint - async def register_fetcher( - self, fetcher: LocalFetcherActor, name: str | ProcMesh - ) -> None: - """Registers a fetcher with the global actor. Each key represents a process mesh. - If there are 2 processes, each with 2 replicas with N gpus, we would - have 4 keys, i.e. 2 proces meshes, each with 2 replicas.""" - self.fetchers[name] = fetcher # pyre-ignore + async def register_fetcher(self, fetcher: LocalFetcherActor, proc_id: str) -> None: + """Registers a LocalFetcherActor with the GlobalLoggingActor. One LocalFetcherActor per ProcMesh. + + Args: + fetcher: The LocalFetcherActor instance for a ProcMesh + proc_id: Unique identifier for the ProcMesh + """ + self.fetchers[proc_id] = fetcher # Self-init for respawned actors if self.config: - logger.debug(f"Initializing new LocalFetcherActor {name}") + logger.debug(f"Initializing new LocalFetcherActor for proc_id={proc_id}") await fetcher.init_backends.call( self.metadata_per_primary_backend, self.config ) @endpoint - async def deregister_fetcher(self, name: str | ProcMesh) -> None: - if name not in self.fetchers: + async def deregister_fetcher(self, proc_id: str) -> None: + if proc_id not in self.fetchers: logger.warning( - f"Fetcher {name} not registered in GlobalLoggingActor. Cannot deregister." + f"Fetcher {proc_id} not registered in GlobalLoggingActor. Cannot deregister." f"Available fetchers: {self.fetchers.keys()}" ) return - del self.fetchers[name] + del self.fetchers[proc_id] @endpoint async def flush(self, global_step: int) -> None: @@ -329,9 +362,9 @@ async def flush(self, global_step: int) -> None: await logger_backend.log(reduced_metrics, global_step) @endpoint - def has_fetcher(self, name: str | ProcMesh) -> bool: - """Check if a fetcher is registered with the given name.""" - return name in self.fetchers + def has_fetcher(self, proc_id: str) -> bool: + """Check if a fetcher is registered with the given proc_id.""" + return proc_id in self.fetchers @endpoint def get_fetcher_count(self) -> int: @@ -339,10 +372,17 @@ def get_fetcher_count(self) -> int: @endpoint async def shutdown(self) -> None: - # Finish per-rank logger_backends via fetchers if self.fetchers: - tasks = [fetcher.shutdown.call() for fetcher in self.fetchers.values()] - await asyncio.gather(*tasks, return_exceptions=True) + try: + tasks = [fetcher.shutdown.call() for fetcher in self.fetchers.values()] + await asyncio.wait_for( + asyncio.gather(*tasks, return_exceptions=True), timeout=2.0 + ) + except asyncio.TimeoutError: + logger.warning( + "Metric logging fetcher shutdown timed out likely due to the child process being terminated before the parent." + ) + # Finish global logger_backends for logger_backend_name, logger_backend in self.global_logger_backends.items(): await logger_backend.finish() diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 3ce849ad2..980bb89fc 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -15,6 +15,8 @@ import pytz from monarch.actor import context, current_rank +from forge.observability.utils import get_proc_name_with_rank + from forge.util.logging import log_once logger = logging.getLogger(__name__) @@ -438,11 +440,14 @@ def __init__(self) -> None: self.rank = current_rank().rank self.logger_backends: list[LoggerBackend] = [] self._is_initialized = False + self.proc_name_with_rank: str | None = None async def init_backends( self, metadata_per_primary_backend: dict[str, dict[str, Any]] | None, config: dict[str, Any], + global_step: int = 0, + process_name: str | None = None, ) -> None: """A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, the backend is instantiated per-rank, in the MetricCollector, otherwise it is only instantiated @@ -452,11 +457,16 @@ async def init_backends( metadata_per_primary_backend (dict[str, dict[str, Any]] | None): Metadata from primary logger backend, e.g., {"wandb": {"run_id": "abc123"}}. config (dict[str, Any]): Logger backend configuration, e.g. {"wandb": {"project": "my_project"}}. + global_step (int, default 0): Initial step for metrics. + process_name (str | None): The meaningful process name for logging. """ if self._is_initialized: logger.debug(f"Rank {self.rank}: MetricCollector already initialized") return + self.global_step = global_step + self.proc_name_with_rank = get_proc_name_with_rank(process_name) + # instantiate local backends if any for backend_name, backend_config in config.items(): if backend_config.get("reduce_across_ranks", True): @@ -470,7 +480,9 @@ async def init_backends( # instantiate local backend logger_backend = get_logger_backend_class(backend_name)(backend_config) await logger_backend.init( - role=BackendRole.LOCAL, primary_logger_metadata=primary_metadata + role=BackendRole.LOCAL, + primary_logger_metadata=primary_metadata, + name=self.proc_name_with_rank, ) self.logger_backends.append(logger_backend) @@ -495,7 +507,8 @@ def push(self, metric: Metric) -> None: logger, level=logging.WARNING, msg=( - "Skipping metric collection. Metric logging backends (e.g. wandb) were not initialized." + f"Skipping metric collection for {get_proc_name_with_rank()}." + " Metric logging backends (e.g. wandb) were not initialized." " This happens when you try to use `record_metric` before calling `init_backends`." " To disable this warning, please call in your main file:\n" "`mlogger = await get_or_create_metric_logger()`\n" @@ -534,7 +547,8 @@ async def flush( log_once( logger, level=logging.WARNING, - msg="Cannot flush collected metrics. MetricCollector.flush() called before init_backends()." + msg=f"Cannot flush collected metrics for {get_proc_name_with_rank()}. " + " MetricCollector.flush() called before init_backends()." "\nPlease call in your main file:\n" "`mlogger = await get_or_create_metric_logger()`\n" "`await mlogger.init_backends.call_one(logging_config)`\n" @@ -544,7 +558,7 @@ async def flush( if not self.accumulators: logger.debug( - f"Collector rank {get_actor_name_with_rank()}: No metrics to flush for global_step {global_step}" + f"Collector {self.proc_name_with_rank}: No metrics to flush for global_step {global_step}" ) return {} @@ -569,7 +583,7 @@ async def shutdown(self): """Shutdown logger_backends if initialized.""" if not self._is_initialized: logger.debug( - f"Collector for {get_actor_name_with_rank()} not initialized. Skipping shutdown" + f"Collector for {self.proc_name_with_rank} not initialized. Skipping shutdown" ) return @@ -593,6 +607,7 @@ async def init( self, role: BackendRole, primary_logger_metadata: dict[str, Any] | None = None, + name: str | None = None, ) -> None: """ Initializes backend, e.g. wandb.run.init(). @@ -602,6 +617,7 @@ async def init( Can be used to behave differently for primary vs secondary roles. primary_logger_metadata (dict[str, Any] | None): From global backend for backend that required shared info, e.g. {"shared_run_id": "abc123"}. + name (str | None): Name for logging. Raises: ValueError if missing metadata for shared local init. """ @@ -618,6 +634,7 @@ async def log(self, metrics: list[Metric], global_step: int) -> None: """ pass + @abstractmethod async def finish(self) -> None: pass @@ -636,12 +653,10 @@ async def init( self, role: BackendRole, primary_logger_metadata: dict[str, Any] | None = None, + name: str | None = None, ) -> None: - self.prefix = ( - get_actor_name_with_rank() - if self.logger_backend_config.get("reduce_across_ranks", True) - else "Controller" - ) + + self.name = name async def log(self, metrics: list[Metric], global_step: int) -> None: metrics_str = "\n".join( @@ -649,7 +664,7 @@ async def log(self, metrics: list[Metric], global_step: int) -> None: for metric in sorted(metrics, key=lambda m: m.key) ) logger.info( - f"=== [{self.prefix}] - METRICS STEP {global_step} ===\n{metrics_str}\n==============================\n" + f"=== [{self.name}] - METRICS STEP {global_step} ===\n{metrics_str}\n==============================\n" ) async def finish(self) -> None: @@ -689,16 +704,13 @@ async def init( self, role: BackendRole, primary_logger_metadata: dict[str, Any] | None = None, + name: str | None = None, ) -> None: if primary_logger_metadata is None: primary_logger_metadata = {} - self.name = ( - get_actor_name_with_rank() - if role == BackendRole.LOCAL - else "global_controller" - ) + self.name = name # Default global mode: only inits on controller if self.reduce_across_ranks: diff --git a/src/forge/observability/utils.py b/src/forge/observability/utils.py new file mode 100644 index 000000000..811bbfe41 --- /dev/null +++ b/src/forge/observability/utils.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Optional + +from monarch.actor import context, current_rank + +logger = logging.getLogger(__name__) + + +def get_proc_name_with_rank(proc_name: Optional[str] = None) -> str: + """ + Returns a unique identifier for the current rank from Monarch actor context. + + Multiple ranks from the same ProcMesh will share the same ProcMesh hash suffix, + but have different rank numbers. + + Format: "{ProcessName}_{ProcMeshHash}_r{rank}" where: + - ProcessName: The provided proc_name (e.g., "TrainActor") or extracted from actor_name if None. + - ProcMeshHash: Hash suffix identifying the ProcMesh (e.g., "1abc2def") + - rank: Local rank within the ProcMesh (0, 1, 2, ...) + + Note: If called from the main process (e.g. main.py), returns "client_r0". + + Args: + proc_name: Optional override for process name. If None, uses actor_id.actor_name. + + Returns: + str: Unique identifier per rank (e.g., "TrainActor_1abc2def_r0" or "client_r0"). + """ + ctx = context() + actor_id = ctx.actor_instance.actor_id + actor_name = actor_id.actor_name + rank = current_rank().rank + + # If proc_name provided, extract procmesh hash from actor_name and combine + if proc_name is not None: + parts = actor_name.split("_") + if len(parts) > 1: + replica_hash = parts[-1] # (e.g., "MyActor_1abc2def" -> "1abc2def") + return f"{proc_name}_{replica_hash}_r{rank}" + else: + # if a direct process (e.g. called from main), actor_name == "client" -> len(parts) == 1 + return f"{proc_name}_r{rank}" + + # No proc_name override - use full actor_name with rank + return f"{actor_name}_r{rank}" diff --git a/src/forge/types.py b/src/forge/types.py index 6a9dcc122..fa77a83de 100644 --- a/src/forge/types.py +++ b/src/forge/types.py @@ -15,15 +15,6 @@ class Message(TypedDict): tools: dict[str, Any] | None -@dataclass -class ForgeEnvInfo: - """Environment info returned with observations.""" - - episode_id: str | None = None - step_count: int = 0 - metadata: dict | None = None - - @dataclass(kw_only=True) class Observation: """Base class for environment observations. @@ -44,50 +35,6 @@ class Observation: metadata: dict[str, Any] = field(default_factory=dict) -@dataclass(kw_only=True) -class Action: - """Base class for environment actions. - - Contract: - - Should contain all information needed to execute a step in the environment - - Should be serializable/deserializable - - Should be immutable (or treated as such) - - Args: - metadata: Additional data that may be useful for logging, debugging, or transforms - """ - - metadata: dict[str, Any] = field(default_factory=dict) - - -@dataclass -class Trajectory: - """A trajectory containing a sequence of states, actions, etc.""" - - policy_version: int - states: list[Observation] = field(default_factory=list) - actions: list[Action] = field(default_factory=list) - - def __post_init__(self): - assert self.policy_version >= 0 - - -@dataclass(kw_only=True) -class State: - """Base class for environment state. - - Contract: - - Should contain all information needed to restore the environment - - Should be serializable/deserializable - - May contain information not exposed in observations - - Args: - metadata: Additional state information that may be useful for debugging or analysis - """ - - metadata: dict[str, Any] = field(default_factory=dict) - - class Launcher(Enum): MAST = "mast" SLURM = "slurm" diff --git a/src/forge/util/_shared_tensor.py b/src/forge/util/_shared_tensor.py new file mode 100644 index 000000000..18a7d65e6 --- /dev/null +++ b/src/forge/util/_shared_tensor.py @@ -0,0 +1,440 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import logging + +import uuid +from dataclasses import dataclass +from multiprocessing import shared_memory +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +@dataclass +class SharedTensorHandle: + shm_name: str + shape: Tuple[int, ...] + dtype: str + + def to_shared_tensor(self) -> SharedTensor: + """ + Create a SharedTensor from this handle. + + Returns: + SharedTensor instance attached to the shared memory referenced by this handle + """ + return SharedTensor(handle=self) + + def drop(self) -> None: + """ + Unlink the shared memory segment. + + This marks the shared memory for deletion. The actual memory will be freed + once all processes have closed their handles to it. + + Note: This only unlinks, it does not close any handles. Processes that have + opened this shared memory should call close() on their SharedTensor instances. + """ + try: + # Attach to the shared memory just to unlink it + shm = shared_memory.SharedMemory(name=self.shm_name) + shm.close() + shm.unlink() + except Exception: + pass + + +class SharedTensor: + """ + Wrapper class for tensors backed by shared memory. + + This class provides a way to share tensors between processes using POSIX shared memory. + It's designed for efficient inter-process tensor communication without copying data. + + Ownership and Lifecycle Model: + ------------------------------ + 1. **Creator process**: + - Creates SharedTensor with tensor data or empty + - Gets a handle via get_handle() to pass to other processes + - **MUST** call close() after getting handle to release its reference + - **SHOULD** call drop()/unlink() when all processes are done + + 2. **Receiver processes**: + - Receive SharedTensorHandle (via RPC, pickle, etc.) + - Create SharedTensor from handle: SharedTensor(handle=handle) + - Use the tensor: handle.to_shared_tensor().tensor + - **MUST** call close() when done using the tensor + + 3. **Cleanup**: + - close(): Closes this process's file descriptor/handle + - drop()/unlink(): Marks shared memory for deletion (call once, from any process) + - Actual memory is freed when all processes have closed AND unlink is called + + Memory Leak Prevention: + ---------------------- + - **DO NOT** rely on __del__ for cleanup! Python GC is unpredictable. + - **ALWAYS** explicitly call close() when done with a SharedTensor + - **ALWAYS** call drop() on handles when sharing is complete + - Use context manager (with statement) for automatic cleanup + - After close(), accessing .tensor will raise RuntimeError + - After close(), getting handle will raise RuntimeError + + Closed State Behavior: + --------------------- + - Once close() is called, the SharedTensor enters a closed state + - Accessing .tensor after close() raises RuntimeError + - Calling get_handle() after close() raises RuntimeError + - You can check the state with the .is_closed property + - close() and drop() are idempotent (safe to call multiple times) + + Important Warning: + ------------------ + If you hold a reference to the tensor BEFORE calling close(), that + reference becomes INVALID after close(): + t = shared.tensor # Get reference + shared.close() # Close SharedTensor - unmaps memory + t.sum() # SEGFAULT! The memory is now invalid + + After close(), the shared memory mapping is unmapped, so ALL references + to the tensor (including cached ones) point to invalid memory. Accessing + them will cause segmentation faults or undefined behavior. + + Always ensure you're done with the tensor before calling close(). + + Example Usage: + ------------- + # Creator process + tensor = torch.randn(100, 100) + shared = SharedTensor(tensor=tensor) + handle = shared.get_handle() + shared.close() # Close creator's reference + # ... send handle to other process via RPC ... + handle.drop() # Unlink after all receivers have it + + # Receiver process + # ... receive handle via RPC ... + shared = SharedTensor(handle=handle) + result = shared.tensor.sum() # Use the tensor + shared.close() # Close receiver's reference + + # Or use context manager (recommended) + with SharedTensor(handle=handle) as shared: + result = shared.tensor.sum() + # Automatically closed + """ + + def __init__( + self, + *, + tensor: Optional[torch.Tensor] = None, + handle: Optional[SharedTensorHandle] = None, + ): + if tensor is not None: + self._create_from_tensor(tensor) + elif handle is not None: + self._create_from_handle(handle) + else: + raise ValueError("Must provide either tensor or handle") + + @classmethod + def empty( + cls, + shape: Union[Tuple[int, ...], torch.Size], + dtype: torch.dtype = torch.float32, + ): + """ + Create an empty tensor directly in shared memory (no copy/allocation overhead) + + Args: + shape: Shape of the tensor + dtype: PyTorch dtype (supports bfloat16, float32, etc.) + + Returns: + SharedTensor instance with uninitialized data + """ + instance = cls.__new__(cls) + instance._create_empty(shape, dtype) + return instance + + @classmethod + def zeros( + cls, + shape: Union[Tuple[int, ...], torch.Size], + dtype: torch.dtype = torch.float32, + ): + """ + Create a zero-initialized tensor in shared memory + + Args: + shape: Shape of the tensor + dtype: PyTorch dtype + + Returns: + SharedTensor instance with zeros + """ + shared_tensor = cls.empty(shape, dtype) + shared_tensor.tensor.zero_() + return shared_tensor + + @classmethod + def ones( + cls, + shape: Union[Tuple[int, ...], torch.Size], + dtype: torch.dtype = torch.float32, + ): + """ + Create a ones-initialized tensor in shared memory + + Args: + shape: Shape of the tensor + dtype: PyTorch dtype + + Returns: + SharedTensor instance with ones + """ + shared_tensor = cls.empty(shape, dtype) + shared_tensor.tensor.fill_(1) + return shared_tensor + + def _create_empty(self, shape, dtype): + """Initialize with empty tensor in shared memory""" + # Initialize lifecycle state + self._closed = False + self._tensor_cache = None + + # Store metadata + self._shape = tuple(shape) if not isinstance(shape, tuple) else shape + self._dtype = dtype + self._dtype_str = str(dtype) + + # Calculate size + element_size = torch.tensor([], dtype=dtype).element_size() + total_elements = int(np.prod(self._shape)) + byte_size = total_elements * element_size + + # Create shared memory (uninitialized - fast!) + shm_name = f"shared_tensor_{uuid.uuid4().hex}" + self._shm = shared_memory.SharedMemory( + create=True, size=byte_size, name=shm_name + ) + self._shm_name = shm_name + + def _create_from_tensor(self, tensor): + """Initialize from an existing tensor""" + # Initialize lifecycle state + self._closed = False + self._tensor_cache = None + + tensor = tensor.contiguous() + + # Store metadata + self._shape = tuple(tensor.shape) + self._dtype = tensor.dtype + self._dtype_str = str(tensor.dtype) + + # Create shared memory + byte_size = tensor.numel() * tensor.element_size() + shm_name = f"shared_tensor_{uuid.uuid4().hex}" + + self._shm = shared_memory.SharedMemory( + create=True, size=byte_size, name=shm_name + ) + self._shm_name = shm_name + + # Copy data as raw bytes + raw_bytes = tensor.view(torch.uint8).view(-1).cpu().contiguous().numpy() + self._shm.buf[:byte_size] = raw_bytes + del raw_bytes # Explicitly free the intermediate numpy array + + def _create_from_handle(self, handle: SharedTensorHandle): + """Initialize from a handle""" + # Initialize lifecycle state + self._closed = False + self._tensor_cache = None + + self._shm_name = handle.shm_name + self._shape = handle.shape + self._dtype_str = handle.dtype + self._dtype = self._parse_dtype(self._dtype_str) + + # Attach to existing shared memory\ + self._shm = shared_memory.SharedMemory(name=self._shm_name) + + def _create_tensor_view(self): + """Create tensor view of shared memory.""" + element_size = torch.tensor([], dtype=self._dtype).element_size() + total_elements = int(np.prod(self._shape)) + byte_size = total_elements * element_size + + # Create numpy array that shares the buffer + np_array = np.ndarray(shape=(byte_size,), dtype=np.uint8, buffer=self._shm.buf) + # Create torch tensor from numpy (shares memory) + uint8_tensor = torch.from_numpy(np_array) + tensor = uint8_tensor.view(self._dtype).reshape(self._shape) + + # Keep the np array alive + tensor._forge_np_array = np_array + + return tensor + + def _parse_dtype(self, dtype_str): + """Parse dtype string""" + dtype_str = dtype_str.replace("torch.", "") + return getattr(torch, dtype_str) + + def get_handle(self): + """ + Get a picklable handle to share this SharedTensor with other processes. + + Returns: + SharedTensorHandle: A lightweight handle that can be pickled and sent to other processes + + Raises: + RuntimeError: If called after close() has been called + """ + if self._closed: + raise RuntimeError( + "Cannot get handle after close(). Get the handle before closing." + ) + return SharedTensorHandle( + shm_name=self._shm_name, + shape=self._shape, + dtype=self._dtype_str, + ) + + @property + def tensor(self): + """ + Get the underlying tensor. + + Returns: + torch.Tensor: View into the shared memory + + Raises: + RuntimeError: If accessed after close() has been called + """ + if self._closed: + raise RuntimeError( + "Cannot access tensor after close(). The SharedTensor has been closed." + ) + if self._tensor_cache is None: + self._tensor_cache = self._create_tensor_view() + return self._tensor_cache + + def copy_from(self, source_tensor): + """ + Copy data from another tensor into this shared tensor + Useful when you create empty tensor first, then fill it + + Args: + source_tensor: Source tensor to copy from + """ + if source_tensor.shape != self._shape: + raise ValueError(f"Shape mismatch: {source_tensor.shape} vs {self._shape}") + # Copy data + self.tensor.copy_(source_tensor) + + def clone(self): + """Create a new SharedTensor with copied data""" + new_shared = SharedTensor.empty(self._shape, self._dtype) + new_shared.tensor.copy_(self.tensor) + return new_shared + + def close(self): + """ + Close this process's handle to the shared memory. + + This should be called when this process is done using the shared memory. + The shared memory will persist until all processes have closed their handles + and someone calls unlink(). + + After calling close(), this SharedTensor object should not be used anymore. + Accessing the tensor property after close() will raise a RuntimeError. + + This method is idempotent - calling it multiple times is safe. + + Note: If you hold a reference to the tensor before calling close(), + that reference will remain valid, but new accesses via shared.tensor + will raise an error. + """ + if self._closed: + return # Already closed, nothing to do + + self._closed = True + self._tensor_cache = None # Release tensor and numpy array references + + try: + self._shm.close() + except Exception as e: + logger.error(f"Error closing shared memory {self._shm_name}: {e}") + + def drop(self): + """ + Close and unlink the shared memory. + + This method first closes this process's handle (if not already closed), + then marks the shared memory for deletion. The actual memory will be freed + once all processes have closed their handles. + + This method is idempotent - calling it multiple times is safe. + + Note: + This should be called when the shared tensor is no longer needed. + Failing to call this method may result in shared memory leaks. + """ + # Close first to set _closed flag and release cache + self.close() + + # Then unlink + try: + self._shm.unlink() + except Exception as e: + raise RuntimeError( + f"Error unlinking shared memory {self._shm_name}: {e}" + ) from e + + @property + def is_closed(self) -> bool: + """ + Check if this SharedTensor has been closed. + + Returns: + bool: True if close() has been called, False otherwise + """ + return self._closed + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit - closes the shared memory handle.""" + self.close() + return False + + def __del__(self): + """ + Best-effort cleanup on garbage collection. + + WARNING: Do NOT rely on __del__ for cleanup! Python's garbage collector + may not call __del__ promptly or at all, which can cause memory leaks. + Always explicitly call close() when done with the SharedTensor. + + This __del__ is only a safety net for cases where explicit cleanup is missed. + """ + # Only close if the object was fully initialized + if hasattr(self, "_closed"): + self.close() + + def __repr__(self): + return f"SharedTensor(shape={self._shape}, dtype={self._dtype}, shm_name={self._shm_name})" diff --git a/src/forge/cli/config.py b/src/forge/util/config.py similarity index 96% rename from src/forge/cli/config.py rename to src/forge/util/config.py index a5e35cefd..2dd171ae3 100644 --- a/src/forge/cli/config.py +++ b/src/forge/util/config.py @@ -56,22 +56,22 @@ def _merge_yaml_and_cli_args(yaml_args: Namespace, cli_args: list[str]) -> DictC cli args, respectively) and merges them into a single OmegaConf DictConfig. If a cli arg overrides a yaml arg with a _component_ field, the cli arg can - be specified with the parent field directly, e.g., model=torchtune.models.lora_llama2_7b - instead of model._component_=torchtune.models.lora_llama2_7b. Nested fields within the + be specified with the parent field directly, e.g., model=my_module.models.my_model + instead of model._component_=my_module.models.my_model. Nested fields within the component should be specified with dot notation, e.g., model.lora_rank=16. Example: >>> config.yaml: >>> a: 1 >>> b: - >>> _component_: torchtune.models.my_model + >>> _component_: my_module.models.my_model >>> c: 3 - >>> tune full_finetune --config config.yaml b=torchtune.models.other_model b.c=4 + >>> python main.py --config config.yaml b=my_module.models.other_model b.c=4 >>> yaml_args, cli_args = parser.parse_known_args() >>> conf = _merge_yaml_and_cli_args(yaml_args, cli_args) >>> print(conf) - >>> {"a": 1, "b": {"_component_": "torchtune.models.other_model", "c": 4}} + >>> {"a": 1, "b": {"_component_": "my_module.models.other_model", "c": 4}} Args: yaml_args (Namespace): Namespace containing args from yaml file, components diff --git a/src/forge/util/logging.py b/src/forge/util/logging.py index e47f5dfa3..9eacf893d 100644 --- a/src/forge/util/logging.py +++ b/src/forge/util/logging.py @@ -20,14 +20,17 @@ def get_logger(level: str | None = None) -> logging.Logger: Example: >>> logger = get_logger("INFO") >>> logger.info("Hello world!") - INFO:torchtune.utils._logging:Hello world! + INFO:forge.util.logging: Hello world! Returns: logging.Logger: The logger. """ logger = logging.getLogger(__name__) if not logger.hasHandlers(): - logger.addHandler(logging.StreamHandler()) + handler = logging.StreamHandler() + formatter = logging.Formatter("%(levelname)s:%(name)s: %(message)s") + handler.setFormatter(formatter) + logger.addHandler(handler) if level is not None: level = getattr(logging, level.upper()) logger.setLevel(level) diff --git a/src/forge/util/metric_logging.py b/src/forge/util/metric_logging.py index ba8992d20..1fe180a35 100644 --- a/src/forge/util/metric_logging.py +++ b/src/forge/util/metric_logging.py @@ -178,7 +178,7 @@ class WandBLogger(MetricLogger): If int, all metrics will be logged at this frequency. If Mapping, calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0` log_dir (str | None): WandB log directory. - project (str): WandB project name. Default is `torchtune`. + project (str): WandB project name. Default is `torchforge`. entity (str | None): WandB entity name. If you don't specify an entity, the run will be sent to your default entity, which is usually your username. group (str | None): WandB group name for grouping runs together. If you don't @@ -205,7 +205,7 @@ class WandBLogger(MetricLogger): def __init__( self, freq: Union[int, Mapping[str, int]], - project: str, + project: str = "torchforge", log_dir: str = "metrics_log", entity: str | None = None, group: str | None = None, diff --git a/tests/README.md b/tests/README.md index d02e49e78..148ab8711 100644 --- a/tests/README.md +++ b/tests/README.md @@ -5,8 +5,8 @@ This directory contains tests for the forge project, including unit tests and in ## Test Structure - `unit_tests/`: Contains unit tests for individual components -- `integration_tests.py`: Contains integration tests that test multiple components together -- `integration_tests_h100.py`: Contains integration tests specifically designed for H100 GPUs, which utilize symmetric memory and float8. +- `integration_tests/`: Contains integration tests that test multiple components together +- `sandbox/`: Contains experimental adhoc scripts used for development and debugging - `assets/`: Contains test assets and fixtures used by the tests ## Running Tests @@ -21,50 +21,49 @@ pip install .[dev] ### Running Integration Tests -To run the integration tests: +To run all integration tests: ```bash -python ./tests/integration_tests.py [--config_dir CONFIG_DIR] [--test TEST] [--ngpu NGPU] +pytest -s tests/integration_tests/ ``` -Arguments: -- `output_dir`: (Required) Directory where test outputs will be stored -- `--config_dir`: (Optional) Directory containing configuration files (default: "./torchtitan/models/llama3/train_configs") -- `--test`: (Optional) Specific test to run, use test names from the `build_test_list()` function (default: "all") -- `--ngpu`: (Optional) Number of GPUs to use for testing (default: 8) +To run a specific integration test file: -Examples: ```bash -# Run all integration tests with 8 GPUs -python ./tests/integration_tests.py ./test_output +pytest -s tests/integration_tests/test_vllm_policy_correctness.py +``` + +To run a specific integration test function: + +```bash +pytest -s tests/integration_tests/test_vllm_policy_correctness.py::test_same_output +``` -# Run a specific test with 4 GPUs -python ./tests/integration_tests.py ./test_output --test default --ngpu 4 +Integration tests support custom options defined in `conftest.py`: +- `--config`: Path to YAML config file for sanity check tests +- `--use_dcp`: Override the YAML config `trainer.use_dcp` field (true/false) -# Run all tests with a custom config directory -python ./tests/integration_tests.py ./test_output --config_dir ./my_configs +Example with options: +```bash +pytest -s tests/integration_tests/ --config ./path/to/config.yaml --use_dcp true ``` ### Running Unit Tests -To run only the unit tests: +To run all unit tests: ```bash pytest -s tests/unit_tests/ ``` -### Running Specific Unit Test Files - -To run a specific test file: +To run a specific unit test file: ```bash -pytest -s tests/unit_tests/test_job_config.py +pytest -s tests/unit_tests/test_config.py ``` -### Running Specific Test Functions in Unit Tests - -To run a specific test function: +To run a specific unit test function: ```bash -pytest -s tests/unit_tests/test_job_config.py::TestJobConfig::test_command_line_args +pytest -s tests/unit_tests/test_config.py::test_cache_hit_scenario ``` diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index 0b99e75a2..9cc0758fa 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -18,10 +18,11 @@ from forge.actors.trainer import RLTrainer from forge.cli.config import resolve_hf_hub_paths -from forge.controller.provisioner import init_provisioner +from forge.controller.provisioner import get_or_create_provisioner from forge.controller.service.service import uuid from forge.types import LauncherConfig, ProvisionerConfig +from forge.util.config import resolve_hf_hub_paths from monarch.actor import endpoint from omegaconf import DictConfig, OmegaConf @@ -194,7 +195,7 @@ async def _setup_and_teardown(request): logger.info(f"`trainer.use_dcp` is overriden to {use_dcp_override}") if cfg.get("provisioner", None) is not None: - await init_provisioner( + await get_or_create_provisioner( ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) ) await ts.initialize(strategy=ts.ControllerStorageVolumes()) diff --git a/tests/integration_tests/test_titan_fwd_vs_hf_fwd.py b/tests/integration_tests/test_titan_fwd_vs_hf_fwd.py index 4fcd850e7..83e8809a7 100644 --- a/tests/integration_tests/test_titan_fwd_vs_hf_fwd.py +++ b/tests/integration_tests/test_titan_fwd_vs_hf_fwd.py @@ -25,9 +25,9 @@ import torch from forge.actors.reference_model import ReferenceModel -from forge.cli.config import _resolve_hf_model_path from forge.controller import ForgeActor from forge.controller.provisioner import shutdown +from forge.util.config import _resolve_hf_model_path from monarch.actor import endpoint from torchtitan.config.job_config import Checkpoint, Compile, Model, Parallelism from transformers import AutoModelForCausalLM, AutoTokenizer diff --git a/tests/integration_tests/test_vllm_policy_correctness.py b/tests/integration_tests/test_vllm_policy_correctness.py index e2da9b068..71ff3677b 100644 --- a/tests/integration_tests/test_vllm_policy_correctness.py +++ b/tests/integration_tests/test_vllm_policy_correctness.py @@ -6,7 +6,7 @@ import pytest -from forge.actors.policy import Policy +from forge.actors.generator import Generator as Policy from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.sampling_params import RequestOutputKind diff --git a/tests/sandbox/rl_trainer/main.py b/tests/sandbox/rl_trainer/main.py index e5ee6fddd..34dd13107 100644 --- a/tests/sandbox/rl_trainer/main.py +++ b/tests/sandbox/rl_trainer/main.py @@ -11,9 +11,8 @@ import torch import torchstore as ts from forge.actors.trainer import RLTrainer -from forge.cli.config import parse from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY -from forge.controller.provisioner import init_provisioner, shutdown +from forge.controller.provisioner import get_or_create_provisioner, shutdown from forge.observability.metric_actors import get_or_create_metric_logger from forge.observability.perf_tracker import Tracer from forge.types import ( @@ -23,6 +22,7 @@ ProvisionerConfig, ServiceConfig, ) +from forge.util.config import parse from omegaconf import DictConfig from vllm.transformers_utils.tokenizer import get_tokenizer @@ -164,7 +164,7 @@ async def main(cfg: DictConfig): trainer_dp_degree = cfg.trainer.parallelism.get("data_parallel_shard_degree", 1) dp_size = trainer_dp_degree if trainer_dp_degree != -1 else 1 - await init_provisioner( + await get_or_create_provisioner( ProvisionerConfig( launcher_config=LauncherConfig( launcher=cfg.get(LAUNCHER_KEY, Launcher.SLURM.value), diff --git a/tests/sandbox/toy_rl/sumdigits.py b/tests/sandbox/toy_rl/sumdigits.py index 0668f8eca..01a0f3936 100644 --- a/tests/sandbox/toy_rl/sumdigits.py +++ b/tests/sandbox/toy_rl/sumdigits.py @@ -18,13 +18,13 @@ from forge.actors._torchstore_utils import get_param_key from forge.actors.generator import Generator from forge.actors.replay_buffer import ReplayBuffer -from forge.cli.config import parse from forge.controller.actor import ForgeActor from forge.controller.provisioner import shutdown from forge.losses.grpo_loss import SimpleGRPOLoss from forge.observability.metric_actors import get_or_create_metric_logger from forge.observability.metrics import record_metric, Reduce +from forge.util.config import parse from forge.util.ops import selective_log_softmax from monarch.actor import endpoint from omegaconf import DictConfig diff --git a/tests/sandbox/toy_rl/toy_metrics/main.py b/tests/sandbox/toy_rl/toy_metrics/main.py index 57ccd97b5..eae50c2db 100644 --- a/tests/sandbox/toy_rl/toy_metrics/main.py +++ b/tests/sandbox/toy_rl/toy_metrics/main.py @@ -95,12 +95,16 @@ async def main(): } service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False} - mlogger = await get_or_create_metric_logger() + mlogger = await get_or_create_metric_logger(process_name="Controller") await mlogger.init_backends.call_one(config) # Spawn services first (triggers registrations via provisioner hook) - trainer = await TrainActor.options(**service_config).as_service() - generator = await GeneratorActor.options(**service_config).as_service() + trainer = await TrainActor.options( + **service_config, mesh_name="TrainActor" + ).as_service() + generator = await GeneratorActor.options( + **service_config, mesh_name="GeneratorActor" + ).as_service() for i in range(3): print(f"\n=== Global Step {i} ===") diff --git a/tests/sandbox/vllm/main.py b/tests/sandbox/vllm/main.py index 0d4652a6b..2c1a0d5e4 100644 --- a/tests/sandbox/vllm/main.py +++ b/tests/sandbox/vllm/main.py @@ -14,13 +14,13 @@ import os from forge.actors.generator import Generator -from forge.cli.config import parse -from forge.controller.provisioner import init_provisioner, shutdown +from forge.controller.provisioner import get_or_create_provisioner, shutdown from forge.data_models.completion import Completion from forge.observability.metric_actors import get_or_create_metric_logger from forge.types import LauncherConfig, ProvisionerConfig +from forge.util.config import parse from omegaconf import DictConfig os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600" @@ -29,11 +29,11 @@ async def run(cfg: DictConfig): if cfg.get("provisioner", None) is not None: - await init_provisioner( + await get_or_create_provisioner( ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) ) metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) - mlogger = await get_or_create_metric_logger() + mlogger = await get_or_create_metric_logger(process_name="Controller") await mlogger.init_backends.call_one(metric_logging_cfg) if (prompt := cfg.get("prompt")) is None: diff --git a/src/forge/cli/__init__.py b/tests/unit_tests/observability/__init__.py similarity index 100% rename from src/forge/cli/__init__.py rename to tests/unit_tests/observability/__init__.py diff --git a/tests/unit_tests/observability/conftest.py b/tests/unit_tests/observability/conftest.py index e8900392c..d256c5d7c 100644 --- a/tests/unit_tests/observability/conftest.py +++ b/tests/unit_tests/observability/conftest.py @@ -9,32 +9,7 @@ from unittest.mock import MagicMock, patch import pytest -from forge.observability.metrics import LoggerBackend, MetricCollector - - -class MockBackend(LoggerBackend): - """Mock backend for testing metrics logging without external dependencies.""" - - def __init__(self, logger_backend_config=None): - super().__init__(logger_backend_config or {}) - self.logged_metrics = [] - self.init_called = False - self.finish_called = False - self.metadata = {} - - async def init(self, role="local", primary_logger_metadata=None): - self.init_called = True - self.role = role - self.primary_logger_metadata = primary_logger_metadata or {} - - async def log(self, metrics, step): - self.logged_metrics.append((metrics, step)) - - async def finish(self): - self.finish_called = True - - def get_metadata_for_secondary_ranks(self): - return self.metadata +from forge.observability.metrics import MetricCollector @pytest.fixture(autouse=True) diff --git a/tests/unit_tests/observability/test_metric_actors.py b/tests/unit_tests/observability/test_metric_actors.py new file mode 100644 index 000000000..1c315b2e9 --- /dev/null +++ b/tests/unit_tests/observability/test_metric_actors.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Optimized unit tests for metric actors functionality.""" + +import pytest + +from forge.observability.metric_actors import ( + get_or_create_metric_logger, + GlobalLoggingActor, + LocalFetcherActor, +) +from monarch.actor import this_host + + +@pytest.fixture +def global_logger(): + """Create a GlobalLoggingActor for testing.""" + p = this_host().spawn_procs(per_host={"cpus": 1}) + return p.spawn("TestGlobalLogger", GlobalLoggingActor) + + +@pytest.fixture +def local_fetcher(global_logger): + """Create a LocalFetcherActor linked to global logger.""" + p = this_host().spawn_procs(per_host={"cpus": 1}) + return p.spawn("TestLocalFetcher", LocalFetcherActor, global_logger) + + +class TestBasicOperations: + """Test basic operations for actors.""" + + @pytest.mark.asyncio + async def test_local_fetcher_flush(self, local_fetcher): + """Test LocalFetcherActor flush operations.""" + result_with_state = await local_fetcher.flush.call_one( + global_step=1, return_state=True + ) + assert result_with_state == {} + + result_without_state = await local_fetcher.flush.call_one( + global_step=1, return_state=False + ) + assert result_without_state == {} + + @pytest.mark.asyncio + async def test_global_logger_basic_ops(self, global_logger): + """Test GlobalLoggingActor basic operations.""" + count = await global_logger.get_fetcher_count.call_one() + assert count >= 0 + + has_fetcher = await global_logger.has_fetcher.call_one("nonexistent") + assert has_fetcher is False + + # Global logger flush (should not raise error) + await global_logger.flush.call_one(global_step=1) + + @pytest.mark.asyncio + async def test_backend_init(self, local_fetcher): + """Test backend initialization and shutdown.""" + metadata = {"wandb": {"shared_run_id": "test123"}} + config = {"console": {"reduce_across_ranks": False}} + + await local_fetcher.init_backends.call_one(metadata, config, global_step=5) + await local_fetcher.shutdown.call_one() + + +class TestRegistrationLifecycle: + """Test registration lifecycle.""" + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_registration_lifecycle(self, global_logger, local_fetcher): + """Test complete registration/deregistration lifecycle.""" + proc_name = "lifecycle_test_proc" + + # Initial state + initial_count = await global_logger.get_fetcher_count.call_one() + assert await global_logger.has_fetcher.call_one(proc_name) is False + + # Register + await global_logger.register_fetcher.call_one(local_fetcher, proc_name) + + # Verify registered + new_count = await global_logger.get_fetcher_count.call_one() + assert new_count == initial_count + 1 + assert await global_logger.has_fetcher.call_one(proc_name) is True + + # Deregister + await global_logger.deregister_fetcher.call_one(proc_name) + + # Verify deregistered + final_count = await global_logger.get_fetcher_count.call_one() + assert final_count == initial_count + assert await global_logger.has_fetcher.call_one(proc_name) is False + + +class TestBackendConfiguration: + """Test backend configuration validation.""" + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_valid_backend_configs(self, global_logger): + """Test valid backend configurations.""" + # Empty config + await global_logger.init_backends.call_one({}) + + # Valid configs for different reduce_across_ranks modes + for reduce_across_ranks in [True, False]: + config = {"console": {"reduce_across_ranks": reduce_across_ranks}} + await global_logger.init_backends.call_one(config) + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_invalid_backend_configs(self, global_logger): + """Test invalid backend configurations are handled gracefully.""" + # Empty config should work + await global_logger.init_backends.call_one({}) + + # Config with only project should work + config_with_project = {"console": {"project": "test_project"}} + await global_logger.init_backends.call_one(config_with_project) + + # Config with reduce_across_ranks should work + config_with_reduce = {"console": {"reduce_across_ranks": True}} + await global_logger.init_backends.call_one(config_with_reduce) + + +class TestErrorHandling: + """Test error handling scenarios.""" + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_deregister_nonexistent_fetcher(self, global_logger): + """Test deregistering non-existent fetcher doesn't crash.""" + await global_logger.deregister_fetcher.call_one("nonexistent_proc") + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_shutdown(self, global_logger): + """Test shutdown without issues.""" + await global_logger.shutdown.call_one() + + +class TestGetOrCreateMetricLogger: + """Test the integration function.""" + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_get_or_create_functionality(self): + """Test get_or_create_metric_logger basic functionality.""" + result = await get_or_create_metric_logger(process_name="TestController") + + # Should return a GlobalLoggingActor mesh + assert result is not None + + # Should be able to call basic methods + count = await result.get_fetcher_count.call_one() + assert count >= 0 diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py index 701bda2dc..b4a8ffcdf 100644 --- a/tests/unit_tests/observability/test_metrics.py +++ b/tests/unit_tests/observability/test_metrics.py @@ -80,12 +80,9 @@ def test_new_enums_and_constants(self): assert isinstance(BackendRole.LOCAL, BackendRole) assert isinstance(BackendRole.GLOBAL, BackendRole) - @patch("forge.observability.metrics.get_actor_name_with_rank") @pytest.mark.asyncio - async def test_backend_role_usage(self, mock_actor_name): + async def test_backend_role_usage(self): """Test that BackendRole constants are actually used instead of string literals.""" - mock_actor_name.return_value = "TestActor_abcd_r0" - # Test ConsoleBackend console_backend = ConsoleBackend({}) await console_backend.init(role=BackendRole.LOCAL) @@ -295,10 +292,8 @@ def test_record_metric_enabled_explicit(self, mock_collector_class, mock_rank): mock_collector_class.assert_called_once() mock_collector.push.assert_called_once() - @patch("forge.observability.metrics.get_actor_name_with_rank") - def test_wandb_backend_creation(self, mock_actor_name): + def test_wandb_backend_creation(self): """Test WandbBackend creation and basic setup without WandB dependency.""" - mock_actor_name.return_value = "TestActor_abcd_r0" config = { "project": "test_project", @@ -316,12 +311,9 @@ def test_wandb_backend_creation(self, mock_actor_name): metadata = backend.get_metadata_for_secondary_ranks() assert metadata == {} # Should be empty when no run - @patch("forge.observability.metrics.get_actor_name_with_rank") @pytest.mark.asyncio - async def test_console_backend(self, mock_actor_name): + async def test_console_backend(self): """Test ConsoleBackend basic operations.""" - mock_actor_name.return_value = "TestActor_abcd_r0" - backend = ConsoleBackend({}) await backend.init(role=BackendRole.LOCAL) @@ -425,28 +417,33 @@ async def _test_fetcher_registration(self, env_var_value, should_register_fetche if hasattr(procs, "_local_fetcher"): delattr(procs, "_local_fetcher") - # Test functionality - global_logger = await get_or_create_metric_logger(proc_mesh=procs) + # Test functionality - pass explicit process_name since test bypasses provisioner + global_logger = await get_or_create_metric_logger( + proc_mesh=procs, process_name="TestProcess" + ) # Get results to check proc_has_fetcher = hasattr(procs, "_local_fetcher") - global_has_fetcher = await global_logger.has_fetcher.call_one(procs) + proc_id = procs._uid if hasattr(procs, "_uid") else None + global_has_fetcher = ( + await global_logger.has_fetcher.call_one(proc_id) if proc_id else False + ) # Assert based on expected behavior if should_register_fetchers: assert ( proc_has_fetcher - ), f"Expected process to have _local_fetcher when {env_var_value=}" + ), f"Expected process to have _local_fetcher when FORGE_DISABLE_METRICS={env_var_value}" assert ( global_has_fetcher - ), f"Expected global logger to have fetcher registered when {env_var_value=}" + ), f"Expected global logger to have fetcher registered when FORGE_DISABLE_METRICS={env_var_value}" else: assert ( not proc_has_fetcher - ), f"Expected process to NOT have _local_fetcher when {env_var_value=}" + ), f"Expected process to NOT have _local_fetcher when FORGE_DISABLE_METRICS={env_var_value}" assert ( not global_has_fetcher - ), f"Expected global logger to NOT have fetcher registered when {env_var_value=}" + ), f"Expected global logger to NOT have fetcher registered when FORGE_DISABLE_METRICS={env_var_value}" @pytest.mark.asyncio @pytest.mark.parametrize( diff --git a/tests/unit_tests/observability/test_utils.py b/tests/unit_tests/observability/test_utils.py new file mode 100644 index 000000000..6b173cc42 --- /dev/null +++ b/tests/unit_tests/observability/test_utils.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for observability utility functions.""" + +from forge.controller.actor import ForgeActor + +from forge.observability.utils import get_proc_name_with_rank +from monarch.actor import endpoint + + +class UtilActor(ForgeActor): + """Actor for testing get_proc_name_with_rank in spawned context.""" + + @endpoint + async def get_name(self) -> str: + return get_proc_name_with_rank() + + +class TestGetProcNameWithRank: + """Tests for get_proc_name_with_rank utility.""" + + def test_direct_proc(self): + """Direct proc should return 'client_r0'.""" + assert get_proc_name_with_rank() == "client_r0" + + def test_direct_proc_with_override(self): + """Direct proc with override should use provided name.""" + result = get_proc_name_with_rank(proc_name="MyProcess") + assert result == "MyProcess_r0" + + # TODO (felipemello): currently not working with CI wheel, but passes locally + # reactive once wheel is updated with new monarch version + # @pytest.mark.timeout(10) + # @pytest.mark.asyncio + # async def test_replicas(self): + # """Test service with replicas returns unique names and hashes per replica.""" + # actor = await UtilActor.options( + # procs=1, num_replicas=2, with_gpus=False + # ).as_service() + # results = await actor.get_name.fanout() + + # assert len(results) == 2 + # assert len(set(results)) == 2 # All names are unique + # for name in results: + # assert name.startswith("UtilActor") + # assert name.endswith("_r0") + + # # Extract hashes from names (format: ActorName_replicaIdx_hash_r0) + # hashes = [name.split("_")[-2] for name in results] + # assert hashes[0] != hashes[1] # Hashes are different between replicas diff --git a/tests/unit_tests/rl/environments/__init__.py b/tests/unit_tests/rl/environments/__init__.py deleted file mode 100644 index 2e41cd717..000000000 --- a/tests/unit_tests/rl/environments/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. diff --git a/tests/unit_tests/rl/environments/test_chat.py b/tests/unit_tests/rl/environments/test_chat.py deleted file mode 100644 index 4abf89dc6..000000000 --- a/tests/unit_tests/rl/environments/test_chat.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import unittest -from typing import Any, Optional -from unittest.mock import MagicMock - -import torch - -from forge.envs.chat import ( - ChatAction, - ChatEnvironment, - ChatObservation, - ChatState, - Message, -) - - -class MockTokenizer: - """Mock tokenizer implementing TokenizerProtocol for testing.""" - - def apply_chat_template( - self, - conversation: list[dict[str, str]], - tools: Optional[list[dict]] = None, - documents: Optional[list[dict[str, str]]] = None, - chat_template: Optional[str] = None, - add_generation_prompt: bool = False, - continue_final_message: bool = False, - tokenize: bool = True, - padding: bool = False, - truncation: bool = False, - max_length: Optional[int] = None, - return_tensors: Optional[str] = None, - return_dict: bool = False, - return_assistant_tokens_mask: bool = False, - tokenizer_kwargs: Optional[dict[str, Any]] = None, - **kwargs, - ) -> torch.Tensor: - """Mock implementation of apply_chat_template.""" - # For testing, we'll just return a tensor with a simple pattern based on the conversation - # Each message contributes 10 tokens to the output - return torch.tensor([[i for i in range(len(conversation) * 10)]]) - - def decode( - self, - token_ids: Any, - skip_special_tokens: bool = False, - clean_up_tokenization_spaces: Optional[bool] = None, - **kwargs, - ) -> str: - """Mock implementation of decode.""" - # For testing, we'll just convert the tensor to a string - if isinstance(token_ids, torch.Tensor): - return f"Decoded: {token_ids.tolist()}" - return f"Decoded: {token_ids}" - - -class TestChatAction(unittest.TestCase): - """Test the ChatAction class.""" - - def test_init(self): - """Test initialization of ChatAction.""" - tokens = torch.tensor([1, 2, 3]) - action = ChatAction(tokens=tokens) - self.assertTrue(torch.equal(action.tokens, tokens)) - - def test_init_empty_tokens(self): - """Test initialization with empty tokens raises ValueError.""" - with self.assertRaises(ValueError): - ChatAction(tokens=torch.tensor([])) - - -class TestChatState(unittest.TestCase): - """Test the ChatState class.""" - - def test_init(self): - """Test initialization of ChatState.""" - state = ChatState() - self.assertEqual(state.history_messages, []) - self.assertEqual(state.history_tokens, []) - - def test_init_with_values(self): - """Test initialization with provided values.""" - messages: list[Message] = [{"role": "user", "content": "Hello"}] - tokens = [torch.tensor([1, 2, 3])] - state = ChatState(history_messages=messages, history_tokens=tokens) - self.assertEqual(state.history_messages, messages) - self.assertEqual(state.history_tokens, tokens) - - -class TestChatObservation(unittest.TestCase): - """Test the ChatObservation class.""" - - def test_init(self): - """Test initialization of ChatObservation.""" - obs = ChatObservation() - self.assertEqual(obs.messages, []) - self.assertEqual(obs.tokens.numel(), 0) - self.assertFalse(obs.done) - self.assertIsNone(obs.reward) - self.assertEqual(obs.metadata, {}) - - def test_init_with_values(self): - """Test initialization with provided values.""" - messages: list[Message] = [{"role": "user", "content": "Hello"}] - tokens = torch.tensor([1, 2, 3]) - obs = ChatObservation( - messages=messages, - tokens=tokens, - done=True, - reward=1.0, - metadata={"test": "value"}, - ) - self.assertEqual(obs.messages, messages) - self.assertTrue(torch.equal(obs.tokens, tokens)) - self.assertTrue(obs.done) - self.assertEqual(obs.reward, 1.0) - self.assertEqual(obs.metadata, {"test": "value"}) - - -class TestChatEnvironment(unittest.TestCase): - """Test the ChatEnvironment class.""" - - def setUp(self): - """Set up test fixtures.""" - self.tokenizer = MockTokenizer() - - def test_init_no_system_prompt(self): - """Test initialization without system prompt.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - self.assertEqual(env._state.history_messages, []) - self.assertEqual(env._state.history_tokens, []) - - def test_init_with_system_prompt(self): - """Test initialization with system prompt.""" - env = ChatEnvironment( - tokenizer=self.tokenizer, - system_prompt="You are a helpful assistant", - system_role="system", - ) - self.assertEqual(len(env._state.history_messages), 1) - self.assertEqual(env._state.history_messages[0]["role"], "system") - self.assertEqual( - env._state.history_messages[0]["content"], "You are a helpful assistant" - ) - self.assertEqual(len(env._state.history_tokens), 1) - - def test_init_invalid_tokenizer(self): - """Test initialization with invalid tokenizer.""" - # Create a mock with no attributes by setting spec=[] - invalid_tokenizer = MagicMock(spec=[]) - with self.assertRaises(ValueError): - ChatEnvironment(tokenizer=invalid_tokenizer) - - def test_reset_no_system_prompt(self): - """Test reset without system prompt.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - # Add some history first - env._state.history_messages = [{"role": "user", "content": "Hello"}] # type: ignore - env._state.history_tokens = [torch.tensor([1, 2, 3])] - - # Reset should clear the history - obs = env.reset() - self.assertEqual(env._state.history_messages, []) - self.assertEqual(env._state.history_tokens, []) - self.assertEqual(obs.messages, []) - self.assertEqual(obs.tokens.numel(), 0) - - def test_reset_with_system_prompt(self): - """Test reset with system prompt.""" - env = ChatEnvironment( - tokenizer=self.tokenizer, - system_prompt="You are a helpful assistant", - system_role="system", - ) - # Add some history first - env._state.history_messages = [ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "Hello"}, - ] # type: ignore - env._state.history_tokens = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])] - - # Reset should clear the history and add the system prompt - obs = env.reset() - self.assertEqual(len(env._state.history_messages), 1) - self.assertEqual(env._state.history_messages[0]["role"], "system") - self.assertEqual( - env._state.history_messages[0]["content"], "You are a helpful assistant" - ) - self.assertEqual(len(env._state.history_tokens), 1) - self.assertEqual(len(obs.messages), 1) - self.assertEqual(obs.messages[0]["role"], "system") - self.assertEqual(obs.messages[0]["content"], "You are a helpful assistant") - - def test_step(self): - """Test step method.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - action = ChatAction(tokens=torch.tensor([1, 2, 3])) - - obs = env.step(action) - - # Check that the tokens were added to history - self.assertEqual(len(env._state.history_tokens), 1) - self.assertTrue( - torch.equal(env._state.history_tokens[0], torch.tensor([1, 2, 3])) - ) - - # Check that the message was added to history with decoded content - self.assertEqual(len(env._state.history_messages), 1) - self.assertEqual(env._state.history_messages[0]["role"], "assistant") - self.assertEqual( - env._state.history_messages[0]["content"], "Decoded: [1, 2, 3]" - ) - - # Check the observation - self.assertEqual(len(obs.messages), 1) - self.assertEqual(obs.messages[0]["role"], "assistant") - self.assertEqual(obs.messages[0]["content"], "Decoded: [1, 2, 3]") - - def test_create_observation(self): - """Test _create_observation method.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - env._state.history_messages = [ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "Hello"}, - ] # type: ignore - env._state.history_tokens = [ - torch.tensor([[1, 2, 3]]), - torch.tensor([[4, 5, 6]]), - ] - - obs = env._create_observation() - - # Check the observation - self.assertEqual(len(obs.messages), 2) - self.assertEqual(obs.messages[0]["role"], "system") - self.assertEqual(obs.messages[0]["content"], "You are a helpful assistant") - self.assertEqual(obs.messages[1]["role"], "user") - self.assertEqual(obs.messages[1]["content"], "Hello") - - # Check that the tokens were concatenated - self.assertEqual(obs.tokens.numel(), 6) # 2 tensors of size 3 - - def test_create_observation_empty_history(self): - """Test _create_observation method with empty history.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - - obs = env._create_observation() - - # Check the observation - self.assertEqual(obs.messages, []) - self.assertEqual(obs.tokens.numel(), 0) - - def test_state_property(self): - """Test state property.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - state = env.state - self.assertIsInstance(state, ChatState) - self.assertEqual(state.history_messages, []) - self.assertEqual(state.history_tokens, []) - - def test_message_to_action(self): - """Test message_to_action method.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - message: Message = {"role": "user", "content": "Hello"} - - action = env.message_to_action(message) - - self.assertIsInstance(action, ChatAction) - self.assertEqual( - action.tokens.numel(), 10 - ) # Mock tokenizer returns 10 tokens per message - - def test_message_to_action_missing_role(self): - """Test message_to_action method with missing role.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - # We're intentionally creating an invalid message to test error handling - message = {"content": "Hello"} # type: ignore - - with self.assertRaises(ValueError): - # Using type: ignore because we're intentionally passing an invalid message - env.message_to_action(message) # type: ignore - - def test_message_to_action_missing_content(self): - """Test message_to_action method with missing content.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - # We're intentionally creating an invalid message to test error handling - message = {"role": "user"} # type: ignore - - with self.assertRaises(ValueError): - # Using type: ignore because we're intentionally passing an invalid message - env.message_to_action(message) # type: ignore - - def test_message_to_action_none_content(self): - """Test message_to_action method with None content.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - # We're intentionally creating an invalid message to test error handling - message = {"role": "user", "content": None} # type: ignore - - with self.assertRaises(ValueError): - # Using type: ignore because we're intentionally passing an invalid message - env.message_to_action(message) # type: ignore - - def test_with_transform(self): - """Test environment with a transform.""" - - def transform(obs): - obs.metadata["transformed"] = True - obs.reward = 1.0 - return obs - - env = ChatEnvironment(tokenizer=self.tokenizer, transform=transform) - action = ChatAction(tokens=torch.tensor([1, 2, 3])) - - obs = env.step(action) - - self.assertTrue(obs.metadata.get("transformed")) - self.assertEqual(obs.reward, 1.0) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 69cc7e2ed..64a00c759 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -8,7 +8,7 @@ import pytest -from forge.cli.config import resolve_hf_hub_paths +from forge.util.config import resolve_hf_hub_paths from omegaconf import DictConfig, OmegaConf @@ -39,7 +39,7 @@ ({"level1": {"level2": {"model": "hf://deep/model"}}}, [("deep/model",)]), ], ) -@patch("forge.cli.config.snapshot_download") +@patch("forge.util.config.snapshot_download") def test_hf_path_resolution(mock_download, config_data, expected_calls): """Test hf:// path resolution in various data structures.""" mock_download.return_value = "/fake/cache/model" @@ -78,7 +78,7 @@ def test_non_hf_paths_unchanged(config_data): # Cache behavior tests -@patch("forge.cli.config.snapshot_download") +@patch("forge.util.config.snapshot_download") def test_cache_hit_scenario(mock_download): """Test behavior when model is already cached.""" mock_download.return_value = "/fake/cache/model" @@ -93,7 +93,7 @@ def test_cache_hit_scenario(mock_download): assert result.model == "/fake/cache/model" -@patch("forge.cli.config.snapshot_download") +@patch("forge.util.config.snapshot_download") def test_cache_miss_scenario(mock_download): """Test behavior when model is not cached.""" from huggingface_hub.utils import LocalEntryNotFoundError @@ -145,7 +145,7 @@ def test_invalid_hf_urls(invalid_hf_url, expected_error): assert expected_error in str(exc_info.value) -@patch("forge.cli.config.snapshot_download") +@patch("forge.util.config.snapshot_download") def test_download_failure_handling(mock_download): """Test error handling when download fails.""" mock_download.side_effect = Exception("Network error: Repository not found") @@ -159,7 +159,7 @@ def test_download_failure_handling(mock_download): # Integration test with mixed data types -@patch("forge.cli.config.snapshot_download") +@patch("forge.util.config.snapshot_download") def test_complex_real_world_config(mock_download): """Test with a realistic complex configuration.""" mock_download.return_value = "/fake/cache/model" diff --git a/tests/unit_tests/test_replay_buffer.py b/tests/unit_tests/test_replay_buffer.py index e6c6876c3..10053b78f 100644 --- a/tests/unit_tests/test_replay_buffer.py +++ b/tests/unit_tests/test_replay_buffer.py @@ -6,10 +6,25 @@ """Test for data/replay_buffer.py""" +from dataclasses import dataclass + import pytest import pytest_asyncio from forge.actors.replay_buffer import ReplayBuffer -from forge.types import Trajectory + + +@dataclass +class TestEpisode: + """ + Dummy Episode containing just a policy version + + ReplayBuffer expects any construct (typically an Episode) that contains a + `policy_version`. + + TODO: Replaced with a unified interface in the future. + """ + + policy_version: int class TestReplayBuffer: @@ -23,27 +38,27 @@ async def replay_buffer(self) -> ReplayBuffer: @pytest.mark.asyncio async def test_add(self, replay_buffer: ReplayBuffer) -> None: - trajectory = Trajectory(policy_version=0) - await replay_buffer.add.call_one(trajectory) + episode = TestEpisode(policy_version=0) + await replay_buffer.add.call_one(episode) assert replay_buffer._numel.call_one().get() == 1 - assert replay_buffer._getitem.call_one(0).get() == trajectory + assert replay_buffer._getitem.call_one(0).get() == episode replay_buffer.clear.call_one().get() @pytest.mark.asyncio async def test_add_multiple(self, replay_buffer) -> None: - trajectory_0 = Trajectory(policy_version=0) - trajectory_1 = Trajectory(policy_version=1) - await replay_buffer.add.call_one(trajectory_0) - await replay_buffer.add.call_one(trajectory_1) + episode_0 = TestEpisode(policy_version=0) + episode_1 = TestEpisode(policy_version=1) + await replay_buffer.add.call_one(episode_0) + await replay_buffer.add.call_one(episode_1) assert replay_buffer._numel.call_one().get() == 2 - assert replay_buffer._getitem.call_one(0).get() == trajectory_0 - assert replay_buffer._getitem.call_one(1).get() == trajectory_1 + assert replay_buffer._getitem.call_one(0).get() == episode_0 + assert replay_buffer._getitem.call_one(1).get() == episode_1 replay_buffer.clear.call_one().get() @pytest.mark.asyncio async def test_state_dict_save_load(self, replay_buffer) -> None: - trajectory = Trajectory(policy_version=0) - await replay_buffer.add.call_one(trajectory) + episode = TestEpisode(policy_version=0) + await replay_buffer.add.call_one(episode) state_dict = replay_buffer.state_dict.call_one().get() replay_buffer.clear.call_one().get() assert replay_buffer._numel.call_one().get() == 0 @@ -53,10 +68,10 @@ async def test_state_dict_save_load(self, replay_buffer) -> None: @pytest.mark.asyncio async def test_evict(self, replay_buffer) -> None: - trajectory_0 = Trajectory(policy_version=0) - trajectory_1 = Trajectory(policy_version=1) - await replay_buffer.add.call_one(trajectory_0) - await replay_buffer.add.call_one(trajectory_1) + episode_0 = TestEpisode(policy_version=0) + episode_1 = TestEpisode(policy_version=1) + await replay_buffer.add.call_one(episode_0) + await replay_buffer.add.call_one(episode_1) assert replay_buffer._numel.call_one().get() == 2 await replay_buffer.evict.call_one(curr_policy_version=2) assert replay_buffer._numel.call_one().get() == 1 @@ -64,10 +79,10 @@ async def test_evict(self, replay_buffer) -> None: @pytest.mark.asyncio async def test_sample(self, replay_buffer) -> None: - trajectory_0 = Trajectory(policy_version=0) - trajectory_1 = Trajectory(policy_version=1) - await replay_buffer.add.call_one(trajectory_0) - await replay_buffer.add.call_one(trajectory_1) + episode_0 = TestEpisode(policy_version=0) + episode_1 = TestEpisode(policy_version=1) + await replay_buffer.add.call_one(episode_0) + await replay_buffer.add.call_one(episode_1) assert replay_buffer._numel.call_one().get() == 2 # Test a simple sampling @@ -77,19 +92,19 @@ async def test_sample(self, replay_buffer) -> None: assert replay_buffer._numel.call_one().get() == 2 # Test sampling (not enough samples in buffer, returns None) - await replay_buffer.add.call_one(trajectory_0) + await replay_buffer.add.call_one(episode_0) samples = await replay_buffer.sample.call_one(curr_policy_version=1) assert samples is None replay_buffer.clear.call_one().get() @pytest.mark.asyncio async def test_sample_with_evictions(self, replay_buffer) -> None: - trajectory_0 = Trajectory(policy_version=0) - trajectory_1 = Trajectory(policy_version=1) - trajectory_2 = Trajectory(policy_version=2) - await replay_buffer.add.call_one(trajectory_0) - await replay_buffer.add.call_one(trajectory_1) - await replay_buffer.add.call_one(trajectory_2) + episode_0 = TestEpisode(policy_version=0) + episode_1 = TestEpisode(policy_version=1) + episode_2 = TestEpisode(policy_version=2) + await replay_buffer.add.call_one(episode_0) + await replay_buffer.add.call_one(episode_1) + await replay_buffer.add.call_one(episode_2) assert replay_buffer._numel.call_one().get() == 3 samples = await replay_buffer.sample.call_one( curr_policy_version=2, @@ -112,8 +127,8 @@ async def test_sample_dp_size(self) -> None: # Add enough trajectories to sample for i in range(10): - trajectory = Trajectory(policy_version=0) - await replay_buffer.add.call_one(trajectory) + episode = TestEpisode(policy_version=0) + await replay_buffer.add.call_one(episode) # Sample and verify len(samples) == dp_size samples = await replay_buffer.sample.call_one(curr_policy_version=0) diff --git a/tests/unit_tests/util/test_shared_tensor.py b/tests/unit_tests/util/test_shared_tensor.py new file mode 100644 index 000000000..f922c3733 --- /dev/null +++ b/tests/unit_tests/util/test_shared_tensor.py @@ -0,0 +1,905 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pickle +import time + +import pytest +import torch + +# Assuming SharedTensor is in shared_tensor.py +from forge.util._shared_tensor import SharedTensor +from multiprocess import Process, Queue + + +class TestSharedTensorCreation: + """Test tensor creation methods""" + + def test_empty_creation(self): + """Test creating empty tensor""" + shape = (100, 200) + dtype = torch.float32 + + shared = SharedTensor.empty(shape, dtype) + + assert shared.tensor.shape == torch.Size(shape) + assert shared.tensor.dtype == dtype + assert shared.tensor.shape == torch.Size(shape) + assert shared.tensor.dtype == dtype + + shared.drop() + + def test_empty_with_bfloat16(self): + """Test creating empty bfloat16 tensor""" + shape = (50, 50) + shared = SharedTensor.empty(shape, torch.bfloat16) + + assert shared.tensor.dtype == torch.bfloat16 + assert shared.tensor.dtype == torch.bfloat16 + + shared.drop() + + def test_zeros_creation(self): + """Test creating zero-initialized tensor""" + shape = (10, 20) + shared = SharedTensor.zeros(shape, torch.float32) + + tensor = shared.tensor + assert torch.all(tensor == 0) + assert tensor.sum().item() == 0.0 + + shared.drop() + + def test_ones_creation(self): + """Test creating ones-initialized tensor""" + shape = (10, 20) + shared = SharedTensor.ones(shape, torch.float32) + + tensor = shared.tensor + assert torch.all(tensor == 1) + assert tensor.sum().item() == 200.0 + + shared.drop() + + def test_from_tensor_creation(self): + """Test creating from existing tensor""" + original = torch.randn(50, 50) + shared = SharedTensor(tensor=original) + + assert shared.tensor.shape == original.shape + assert shared.tensor.dtype == original.dtype + assert torch.allclose(shared.tensor, original) + + shared.drop() + + def test_from_handle_creation(self): + """Test creating from handle""" + # Create original + original = SharedTensor.empty((10, 10), torch.float32) + original.tensor.fill_(5.0) + + # Get handle + handle = original.get_handle() + + # Create from handle + reconstructed = SharedTensor(handle=handle) + + assert torch.all(reconstructed.tensor == 5.0) + assert reconstructed.tensor.shape == original.tensor.shape + assert reconstructed.tensor.dtype == original.tensor.dtype + + original.drop() + + def test_creation_requires_argument(self): + """Test that creation without arguments raises error""" + with pytest.raises(ValueError, match="Must provide either tensor or handle"): + SharedTensor() + + @pytest.mark.parametrize( + "shape", + [ + (10,), + (10, 20), + (5, 10, 15), + (2, 3, 4, 5), + ], + ) + def test_various_shapes(self, shape): + """Test creation with various shapes""" + shared = SharedTensor.empty(shape, torch.float32) + assert shared.tensor.shape == torch.Size(shape) + assert shared.tensor.shape == torch.Size(shape) + shared.drop() + + +class TestSharedTensorDtypes: + """Test all supported dtypes""" + + @pytest.mark.parametrize( + "dtype", + [ + torch.float32, + torch.float64, + torch.float16, + torch.bfloat16, + torch.int32, + torch.int64, + torch.int16, + torch.int8, + torch.uint8, + torch.bool, + ], + ) + def test_all_dtypes(self, dtype): + """Test that all dtypes work correctly""" + shape = (10, 10) + shared = SharedTensor.empty(shape, dtype) + + assert shared.tensor.dtype == dtype + assert shared.tensor.dtype == dtype + + # Test that we can write to it + if dtype == torch.bool: + shared.tensor.fill_(True) + elif dtype in [torch.int32, torch.int64, torch.int16, torch.int8, torch.uint8]: + shared.tensor.fill_(42) + else: + shared.tensor.fill_(3.14) + + shared.drop() + + def test_dtype_conversion_in_handle(self): + """Test dtype is preserved through handle""" + for dtype in [torch.float32, torch.bfloat16, torch.int64]: + shared1 = SharedTensor.empty((5, 5), dtype) + handle = shared1.get_handle() + + shared2 = SharedTensor(handle=handle) + assert shared2.tensor.dtype == dtype + + shared1.drop() + + +class TestSharedTensorOperations: + """Test tensor operations""" + + def test_copy_from(self): + """Test copying data from another tensor""" + source = torch.randn(20, 30) + shared = SharedTensor.empty((20, 30), torch.float32) + + shared.copy_from(source) + + assert torch.allclose(shared.tensor, source) + shared.drop() + + def test_copy_from_shape_mismatch(self): + """Test copy_from raises error on shape mismatch""" + source = torch.randn(10, 10) + shared = SharedTensor.empty((20, 20), torch.float32) + + with pytest.raises(ValueError, match="Shape mismatch"): + shared.copy_from(source) + + shared.drop() + + def test_clone(self): + """Test cloning creates independent copy""" + original = SharedTensor.empty((10, 10), torch.float32) + original.tensor.fill_(5.0) + + cloned = original.clone() + + # Verify data is same + assert torch.all(cloned.tensor == 5.0) + + # Verify they're independent + original.tensor.fill_(10.0) + assert torch.all(cloned.tensor == 5.0) + assert torch.all(original.tensor == 10.0) + + original.drop() + cloned.drop() + + def test_tensor_modifications(self): + """Test that modifications to tensor are reflected""" + shared = SharedTensor.zeros((10, 10), torch.float32) + tensor = shared.tensor + + tensor[0, 0] = 99.0 + tensor[5:, :] = 42.0 + + # Get tensor again and verify changes persist + tensor2 = shared.tensor + assert tensor2[0, 0].item() == 99.0 + assert torch.all(tensor2[5:, :] == 42.0) + + shared.drop() + + def test_inplace_operations(self): + """Test in-place operations work""" + shared = SharedTensor.empty((100, 100), torch.float32) + tensor = shared.tensor + + tensor.normal_(0, 1) + mean = tensor.mean().item() + + tensor.add_(5.0) + new_mean = tensor.mean().item() + + assert abs(new_mean - (mean + 5.0)) < 0.1 + + shared.drop() + + +class TestSharedTensorSerialization: + """Test pickling and handle serialization""" + + def test_handle_is_picklable(self): + """Test that handle can be pickled""" + shared = SharedTensor.empty((10, 10), torch.float32) + handle = shared.get_handle() + + # Pickle and unpickle + pickled = pickle.dumps(handle) + unpickled_handle = pickle.loads(pickled) + + assert unpickled_handle == handle + + shared.drop() + + def test_handle_small_size(self): + """Test that handle is small (efficient for RPC)""" + shared = SharedTensor.empty((10000, 10000), torch.float32) + handle = shared.get_handle() + + pickled = pickle.dumps(handle) + + # Handle should be < 1KB even for huge tensors + assert len(pickled) < 1024 + + shared.drop() + + def test_data_integrity_after_pickle(self): + """Test data is preserved through handle pickling""" + # Create and fill tensor + shared1 = SharedTensor.empty((50, 50), torch.bfloat16) + shared1.tensor.normal_(0, 1) + original_data = shared1.tensor.clone() + + # Pickle handle + handle = shared1.get_handle() + pickled = pickle.dumps(handle) + unpickled_handle = pickle.loads(pickled) + + # Reconstruct + shared2 = SharedTensor(handle=unpickled_handle) + + # Verify data is same + assert torch.allclose(shared2.tensor.float(), original_data.float(), rtol=1e-3) + + shared1.drop() + + +class TestSharedTensorMemory: + """Test memory management and cleanup""" + + def test_drop(self): + """Test drop removes shared memory""" + shared = SharedTensor.empty((10, 10), torch.float32) + shm_name = shared._shm_name + + # Verify shared memory exists + tensor = shared.tensor + tensor.fill_(5.0) + + # Drop shared memory + shared.drop() + + # Trying to attach should fail + from multiprocessing import shared_memory + + with pytest.raises(FileNotFoundError): + shared_memory.SharedMemory(name=shm_name) + + def test_multiple_views_same_memory(self): + """Test multiple tensor views point to same memory""" + shared = SharedTensor.empty((10, 10), torch.float32) + + tensor1 = shared.tensor + tensor1.fill_(5.0) + + tensor2 = shared.tensor + assert torch.all(tensor2 == 5.0) + + # Modify through tensor2 + tensor2.fill_(10.0) + + # Verify tensor1 sees the change + assert torch.all(tensor1 == 10.0) + + shared.drop() + + def test_handle_reconstruction_shares_memory(self): + """Test that handle reconstruction shares same memory""" + shared1 = SharedTensor.empty((20, 20), torch.float32) + shared1.tensor.fill_(7.0) + + handle = shared1.get_handle() + shared2 = SharedTensor(handle=handle) + + # Modify through shared2 + shared2.tensor.fill_(14.0) + + # Verify shared1 sees the change + assert torch.all(shared1.tensor == 14.0) + + shared1.drop() + + +class TestSharedTensorEdgeCases: + """Test edge cases and error conditions""" + + def test_empty_shape(self): + """Test scalar tensor (empty shape)""" + shared = SharedTensor.ones((), torch.float32) + assert shared.tensor.shape == () + assert shared.tensor.numel() == 1 + assert torch.allclose( + shared.tensor, + torch.ones( + (), + ), + ) + shared.drop() + + def test_single_element_tensor(self): + """Test 1-element tensor""" + shared = SharedTensor.empty((1,), torch.float32) + shared.tensor.fill_(42.0) + assert shared.tensor.item() == 42.0 + shared.drop() + + def test_large_tensor(self): + """Test large tensor (1GB)""" + # 1GB tensor: 250M float32 elements + shape = (250_000_000,) + shared = SharedTensor.empty(shape, torch.float32) + + assert shared.tensor.shape == shape + assert shared.tensor.numel() == 250_000_000 + + shared.drop() + + def test_non_contiguous_tensor_conversion(self): + """Test that non-contiguous tensors are handled""" + # Create non-contiguous tensor + original = torch.randn(10, 10).t() # Transpose makes it non-contiguous + assert not original.is_contiguous() + + # Should work (internally makes contiguous) + shared = SharedTensor(tensor=original) + + # Result should match + assert torch.allclose(shared.tensor, original) + + shared.drop() + + def test_repr(self): + """Test string representation""" + shared = SharedTensor.empty((10, 20), torch.float32) + repr_str = repr(shared) + + assert "SharedTensor" in repr_str + assert "10, 20" in repr_str + assert "float32" in repr_str + assert shared._shm_name in repr_str + + shared.drop() + + +class TestSharedTensorMultiprocess: + """Test multiprocess scenarios""" + + def test_multiprocess_read(self): + """Test reading shared tensor from another process""" + + def reader_process(handle_dict, result_queue): + with SharedTensor(handle=handle_dict) as shared: + result_queue.put(shared.tensor.sum().item()) + + # Create shared tensor in main process + shared = SharedTensor.empty((100, 100), torch.float32) + shared.tensor.fill_(5.0) + + # Read from child process + result_queue = Queue() + handle = shared.get_handle() + + p = Process(target=reader_process, args=(handle, result_queue)) + p.start() + p.join() + + result = result_queue.get() + expected = 5.0 * 100 * 100 + + assert abs(result - expected) < 1e-5 + + shared.drop() + + def test_multiprocess_write(self): + """Test writing to shared tensor from another process""" + + def writer_process(handle_dict, value): + with SharedTensor(handle=handle_dict) as shared: + shared.tensor.fill_(value) + + # Create empty shared tensor + shared = SharedTensor.empty((50, 50), torch.float32) + shared.tensor.zero_() + + # Write from child process + handle = shared.get_handle() + + p = Process(target=writer_process, args=(handle, 42.0)) + p.start() + p.join() + + # Verify in main process + assert torch.all(shared.tensor == 42.0) + + shared.drop() + + def test_multiprocess_bidirectional(self): + """Test bidirectional communication""" + + def worker_process(input_handle, output_handle): + with SharedTensor(handle=input_handle) as input_shared: + with SharedTensor(handle=output_handle) as output_shared: + # Compute: output = input * 2 + output_shared.tensor.copy_(input_shared.tensor * 2) + + # Create input and output tensors + input_shared = SharedTensor.empty((100, 100), torch.float32) + input_shared.tensor.normal_(0, 1) + input_data = input_shared.tensor.clone() + + output_shared = SharedTensor.empty((100, 100), torch.float32) + + # Process in child + p = Process( + target=worker_process, + args=(input_shared.get_handle(), output_shared.get_handle()), + ) + p.start() + p.join() + + # Verify result + expected = input_data * 2 + assert torch.allclose( + output_shared.tensor, expected + ), "output: {}, expected: {}".format(output_shared.tensor, expected) + + input_shared.drop() + output_shared.drop() + + +class TestSharedTensorPerformance: + """Performance-related tests""" + + def test_empty_faster_than_from_tensor(self): + """Test that empty() is faster than from tensor""" + shape = (1000, 1000) + + # Time empty creation + start = time.time() + for _ in range(10): + shared = SharedTensor.empty(shape, torch.float32) + shared.drop() + empty_time = time.time() - start + + # Time from_tensor creation + start = time.time() + for _ in range(10): + tensor = torch.randn(shape) + shared = SharedTensor(tensor=tensor) + shared.drop() + from_tensor_time = time.time() - start + + # empty() should be faster (no data copying) + assert empty_time < from_tensor_time + + def test_handle_serialization_fast(self): + """Test that handle serialization is fast""" + shared = SharedTensor.empty((10000, 10000), torch.float32) + handle = shared.get_handle() + + start = time.time() + for _ in range(1000): + pickled = pickle.dumps(handle) + unpickled = pickle.loads(pickled) + elapsed = time.time() - start + + # Should be able to do 1000 round trips in < 0.1 seconds + assert elapsed < 0.1 + + shared.drop() + + +class TestSharedTensorHandleToSharedTensor: + """Test SharedTensorHandle.to_shared_tensor() method""" + + def test_to_shared_tensor_basic(self): + """Test basic creation of SharedTensor from handle using to_shared_tensor method""" + original = SharedTensor.empty((10, 10), torch.float32) + original.tensor.fill_(7.0) + + handle = original.get_handle() + reconstructed = handle.to_shared_tensor() + + assert torch.all(reconstructed.tensor == 7.0) + assert reconstructed.tensor.shape == original.tensor.shape + assert reconstructed.tensor.dtype == original.tensor.dtype + + original.drop() + + def test_to_shared_tensor_preserves_data(self): + """Test that to_shared_tensor preserves original data""" + original = SharedTensor.empty((20, 30), torch.float32) + original.tensor.normal_(0, 1) + original_data = original.tensor.clone() + + handle = original.get_handle() + reconstructed = handle.to_shared_tensor() + + assert torch.allclose(reconstructed.tensor, original_data) + + original.drop() + + def test_to_shared_tensor_shares_memory(self): + """Test that to_shared_tensor shares memory with original""" + original = SharedTensor.empty((15, 15), torch.float32) + original.tensor.fill_(5.0) + + handle = original.get_handle() + reconstructed = handle.to_shared_tensor() + + reconstructed.tensor.fill_(10.0) + + assert torch.all(original.tensor == 10.0) + + original.drop() + + def test_to_shared_tensor_with_various_dtypes(self): + """Test to_shared_tensor works with different data types""" + for dtype in [torch.float32, torch.float64, torch.bfloat16, torch.int32]: + original = SharedTensor.empty((5, 5), dtype) + if ( + dtype == torch.bfloat16 + or dtype == torch.float32 + or dtype == torch.float64 + ): + original.tensor.normal_(0, 1) + else: + original.tensor.fill_(42) + + handle = original.get_handle() + reconstructed = handle.to_shared_tensor() + + assert reconstructed.tensor.dtype == dtype + if dtype == torch.bfloat16: + assert torch.allclose( + reconstructed.tensor.float(), original.tensor.float(), rtol=1e-3 + ) + else: + assert torch.allclose(reconstructed.tensor, original.tensor) + + original.drop() + + def test_to_shared_tensor_multiprocess(self): + """Test to_shared_tensor in multiprocess scenario""" + + def worker_process(handle, result_queue): + with handle.to_shared_tensor() as shared: + result_queue.put(shared.tensor.sum().item()) + + original = SharedTensor.empty((50, 50), torch.float32) + original.tensor.fill_(3.0) + + handle = original.get_handle() + result_queue = Queue() + + p = Process(target=worker_process, args=(handle, result_queue)) + p.start() + p.join() + + result = result_queue.get() + expected = 3.0 * 50 * 50 + + assert abs(result - expected) < 1e-5 + + original.drop() + + def test_to_shared_tensor_equivalent_to_constructor(self): + """Test that handle.to_shared_tensor() is equivalent to SharedTensor(handle=handle)""" + original = SharedTensor.empty((25, 25), torch.float32) + original.tensor.normal_(0, 1) + + handle = original.get_handle() + + via_method = handle.to_shared_tensor() + via_constructor = SharedTensor(handle=handle) + + assert torch.allclose(via_method.tensor, via_constructor.tensor) + assert via_method.tensor.shape == via_constructor.tensor.shape + assert via_method.tensor.dtype == via_constructor.tensor.dtype + + original.drop() + + +class TestSharedTensorBfloat16: + """Specific tests for bfloat16 support""" + + def test_bfloat16_creation(self): + """Test bfloat16 tensor creation""" + shared = SharedTensor.empty((100, 100), torch.bfloat16) + assert shared.tensor.dtype == torch.bfloat16 + shared.drop() + + def test_bfloat16_from_tensor(self): + """Test creating shared tensor from bfloat16 tensor""" + original = torch.randn(50, 50, dtype=torch.bfloat16) + shared = SharedTensor(tensor=original) + + assert shared.tensor.dtype == torch.bfloat16 + assert torch.allclose(shared.tensor.float(), original.float(), rtol=1e-3) + + shared.drop() + + def test_bfloat16_handle_preservation(self): + """Test bfloat16 dtype preserved through handle""" + shared1 = SharedTensor.empty((20, 20), torch.bfloat16) + shared1.tensor.normal_(0, 1) + + handle = shared1.get_handle() + shared2 = SharedTensor(handle=handle) + + assert shared2.tensor.dtype == torch.bfloat16 + assert torch.allclose(shared1.tensor.float(), shared2.tensor.float(), rtol=1e-3) + + shared1.drop() + + def test_bfloat16_operations(self): + """Test operations on bfloat16 tensors""" + shared = SharedTensor.empty((100, 100), torch.bfloat16) + tensor = shared.tensor + + tensor.normal_(0, 1) + mean = tensor.float().mean().item() + + # Mean should be close to 0 + assert abs(mean) < 0.1 + + shared.drop() + + +class TestSharedTensorCloseAndCleanup: + """Test explicit close() and cleanup patterns to prevent memory leaks""" + + def test_close_method(self): + """Test explicit close() releases handle and sets closed state""" + shared = SharedTensor.empty((10, 10), torch.float32) + shared.tensor.fill_(5.0) + + assert not shared.is_closed + + # Close should not raise + shared.close() + + assert shared.is_closed + + # Cleanup + shared._shm.unlink() + + def test_tensor_access_after_close_raises_error(self): + """Test that accessing tensor after close raises RuntimeError""" + shared = SharedTensor.empty((10, 10), torch.float32) + shared.tensor.fill_(5.0) + + shared.close() + + with pytest.raises(RuntimeError, match="Cannot access tensor after close"): + _ = shared.tensor + + # Cleanup + shared._shm.unlink() + + def test_get_handle_after_close_raises_error(self): + """Test that getting handle after close raises RuntimeError""" + shared = SharedTensor.empty((10, 10), torch.float32) + + shared.close() + + with pytest.raises(RuntimeError, match="Cannot get handle after close"): + shared.get_handle() + + # Cleanup + shared._shm.unlink() + + def test_is_closed_property(self): + """Test is_closed property reflects state correctly""" + shared = SharedTensor.empty((10, 10), torch.float32) + + assert not shared.is_closed + + shared.close() + + assert shared.is_closed + + # Cleanup + shared._shm.unlink() + + def test_cached_tensor_reference_becomes_invalid_after_close(self): + """Test that tensor reference obtained before close becomes invalid after close""" + shared = SharedTensor.empty((10, 10), torch.float32) + shared.tensor.fill_(5.0) + + # Get reference before close + tensor_ref = shared.tensor + + shared.close() + + # After close(), the memory mapping is unmapped, so even cached references + # point to invalid memory. Accessing them will cause segfault or undefined behavior. + # We can't safely test this, but we document it. + + # Accessing via shared.tensor raises error (this is what we CAN test) + with pytest.raises(RuntimeError): + _ = shared.tensor + + # Cleanup + shared._shm.unlink() + + def test_context_manager(self): + """Test context manager automatically closes""" + shm_name = None + + with SharedTensor.empty((10, 10), torch.float32) as shared: + shm_name = shared._shm_name + shared.tensor.fill_(7.0) + assert torch.all(shared.tensor == 7.0) + + # After exiting context, should be closed (but not unlinked yet) + # We need to unlink separately + from multiprocessing import shared_memory + + # Should still be able to attach (not unlinked) + shm = shared_memory.SharedMemory(name=shm_name) + shm.close() + shm.unlink() + + def test_creator_receiver_workflow(self): + """Test proper workflow: creator creates, gets handle, closes, receiver uses and closes""" + + def receiver_process(handle, result_queue): + # Receiver creates SharedTensor from handle + with SharedTensor(handle=handle) as shared: + result = shared.tensor.sum().item() + result_queue.put(result) + # Context manager auto-closes + + # Creator process + shared = SharedTensor.empty((50, 50), torch.float32) + shared.tensor.fill_(4.0) + handle = shared.get_handle() + shared.close() # Creator closes its reference + + # Pass to receiver + result_queue = Queue() + p = Process(target=receiver_process, args=(handle, result_queue)) + p.start() + p.join() + + result = result_queue.get() + assert abs(result - (4.0 * 50 * 50)) < 1e-5 + + # Unlink after all processes done + handle.drop() + + def test_handle_drop_without_creating_shared_tensor(self): + """Test that handle.drop() doesn't create unnecessary SharedTensor instance""" + shared = SharedTensor.empty((10, 10), torch.float32) + shared.tensor.fill_(3.0) + handle = shared.get_handle() + shared.close() + + # drop() should work without creating new SharedTensor + handle.drop() + + # Memory should be unlinked + from multiprocessing import shared_memory + + with pytest.raises(FileNotFoundError): + shared_memory.SharedMemory(name=handle.shm_name) + + def test_multiple_receivers_close_independently(self): + """Test that multiple receivers can close independently""" + + def receiver_process(handle, value, result_queue): + with SharedTensor(handle=handle) as shared: + result = shared.tensor[0, 0].item() == value + result_queue.put(result) + + # Creator + shared = SharedTensor.empty((10, 10), torch.float32) + shared.tensor.fill_(9.0) + handle = shared.get_handle() + shared.close() + + # Multiple receivers + result_queue = Queue() + processes = [] + for _ in range(3): + p = Process(target=receiver_process, args=(handle, 9.0, result_queue)) + p.start() + processes.append(p) + + for p in processes: + p.join() + + # All should succeed + for _ in range(3): + assert result_queue.get() is True + + # Cleanup + handle.drop() + + def test_close_is_idempotent(self): + """Test that calling close() multiple times is safe""" + shared = SharedTensor.empty((10, 10), torch.float32) + + # Multiple closes should not raise + shared.close() + shared.close() + shared.close() + + # Cleanup + shared.drop() + + def test_drop_is_idempotent(self): + """Test that calling drop() multiple times is safe""" + shared = SharedTensor.empty((10, 10), torch.float32) + handle = shared.get_handle() + shared.close() + + # Multiple drops should not raise + handle.drop() + handle.drop() + handle.drop() + + def test_proper_cleanup_prevents_leak(self): + """Test that proper close + unlink pattern doesn't leak""" + import glob + + # Get initial shared memory count + shm_before = len(glob.glob("/dev/shm/shared_tensor_*")) + + # Create and properly cleanup 10 shared tensors + for _ in range(10): + shared = SharedTensor.empty((100, 100), torch.float32) + handle = shared.get_handle() + shared.close() + handle.drop() + + # Check no leaks + shm_after = len(glob.glob("/dev/shm/shared_tensor_*")) + assert ( + shm_after == shm_before + ), f"Memory leak detected: {shm_after - shm_before} tensors leaked" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"])