diff --git a/README.md b/README.md index c77e7cf5..21634c6b 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,10 @@ An e2e framework for creating, deploying and using isolated execution environmen [![Discord](https://img.shields.io/badge/Discord-OpenEnv-7289da?style=flat&logo=discord&logoColor=white)](https://discord.gg/YsTYBh6PD9) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/meta-pytorch/OpenEnv/blob/main/examples/OpenEnv_Tutorial.ipynb) **← Try the Interactive Tutorial!** +--- + +**šŸš€ Featured Example:** Train LLMs to play BlackJack using [torchforge](https://github.com/meta-pytorch/torchforge) (PyTorch's agentic RL framework): [`examples/grpo_blackjack/`](examples/grpo_blackjack/) + ## OpenEnv on partner platforms: - [Lightning AI Studio](https://lightning.ai/environments?section=featured) @@ -178,10 +182,10 @@ client.close() # Stops and removes container - smolagents (for coding environment) ## Supported RL Tools -The goal of this project is to support a broad set of open and closed tools to help standardize the agentic RL community. If you have a project that supports OpenEnv environments, please put up a PR to add your tool name along with a link to your documentation. +The goal of this project is to support a broad set of open and closed tools to help standardize the agentic RL community. If you have a project that supports OpenEnv environments, please put up a PR to add your tool name along with a link to your documentation. ### torchforge -(coming soon) +See GRPO BlackJack training example: [`examples/grpo_blackjack/`](examples/grpo_blackjack/) ### TRL (coming soon} diff --git a/examples/grpo_blackjack/README.md b/examples/grpo_blackjack/README.md new file mode 100644 index 00000000..8d141607 --- /dev/null +++ b/examples/grpo_blackjack/README.md @@ -0,0 +1,191 @@ +# Training LLMs to Play BlackJack with GRPO + OpenEnv + +This example demonstrates how to train language models to play BlackJack using **GRPO (Group Relative Policy Optimization)** and **OpenEnv**. + +## šŸŽÆ What This Example Shows + +- **OpenEnv**: Universal RL environment interface for 70+ environments +- **GRPO**: Efficient RL algorithm (used by DeepSeek R1) that only needs 2 models instead of 3 +- **Forge**: PyTorch-native agentic RL library for production training +- **End-to-End Training**: From random policy (~35% win rate) to trained agent + +## šŸ“ Files + +- `grpo_blackjack_tutorial.ipynb` - Interactive tutorial notebook (recommended starting point) +- `grpo_utils.py` - Production GRPO utilities and helper functions +- `blackjack.yaml` - Training configuration file +- `README.md` - This file + +## šŸš€ Quick Start + +### Prerequisites + +1. **Install OpenEnv**: + ```bash + # Clone OpenEnv repo + git clone https://github.com/meta-pytorch/OpenEnv.git + cd OpenEnv + pip install -e . + ``` + +2. **Install Forge** (PyTorch's agentic RL library): + ```bash + git clone https://github.com/meta-pytorch/torchforge.git + cd torchforge + pip install -e . + ``` + +3. **Start OpenEnv BlackJack Server**: + ```bash + # In a separate terminal + export OPENENV_PATH="/path/to/OpenEnv/src" + export PYTHONPATH="${OPENENV_PATH}:${PYTHONPATH}" + + OPENSPIEL_GAME=blackjack python -m envs.openspiel_env.server.app --port 8004 + ``` + +### Run the Tutorial + +Open the Jupyter notebook: +```bash +jupyter notebook grpo_blackjack_tutorial.ipynb +``` + +Follow the cells to: +1. **Explore OpenEnv** - Connect to BlackJack environment +2. 
**Benchmark baseline** - Test random policy performance +3. **Learn about GRPO** - Understand the training algorithm +4. **Train with Forge** - Run production GRPO training +5. **Switch environments** - See how to train on other games + +## šŸ“š What You'll Learn + +### OpenEnv: Universal RL Environment Spec + +OpenEnv is **not a game engine** - it's a **specification** that wraps ANY RL environment: + +```python +# Same interface works for 70+ environments +result = env.reset() # Start episode +result = env.step(action) # Take action +state = env.state() # Get state +env.close() # Cleanup +``` + +Change one environment variable → train on different games! + +### Forge: PyTorch-Native Agentic RL + +Forge handles all distributed systems complexity: +- **Generator (vLLM)**: Fast LLM inference +- **RLTrainer**: Distributed training with FSDP +- **ReplayBuffer**: Off-policy learning +- **ReferenceModel**: KL penalty computation +- **Torchstore**: Distributed weight management + +You just write: +```python +trainer = await setup_forge_training("blackjack.yaml") +await trainer.run(steps=100) +``` + +Everything else is automated! + +## šŸŽ“ Educational Resources + +This tutorial is inspired by the excellent [Unsloth RL Guide](https://docs.unsloth.ai/get-started/reinforcement-learning-rl-guide). We highly recommend reading it for deeper insights! + +### Further Reading + +- **OpenEnv**: [GitHub](https://github.com/meta-pytorch/OpenEnv) +- **GRPO Paper**: [arXiv:2402.03300](https://arxiv.org/abs/2402.03300) +- **Forge**: [GitHub](https://github.com/meta-pytorch/torchforge) | [Docs](https://meta-pytorch.org/torchforge/) +- **Unsloth RL Guide**: [docs.unsloth.ai](https://docs.unsloth.ai/get-started/reinforcement-learning-rl-guide) + +## šŸ’” Key Concepts + +### "Patience Is All You Need" for RL + +RL works by patience: if the correct answer has *any* non-zero probability, we'll eventually find it through sampling. While waiting: +1. Learn from **bad answers** → decrease their probability +2. When finding **good answers** → increase their probability + +Over time, the model learns not just *what* to do, but *why* (reasoning process). + +### Reward Functions + +Reward functions tell the model what's good/bad. For BlackJack: + +```python +def evaluate_response(prompt, response, game_reward): + reward = float(game_reward) # +1 (win), -1 (loss), 0 (push) + + # Reward shaping + if game_reward > 0: + reward = 2.0 # Wins more valuable + elif game_reward == 0: + reward = 0.5 # Pushes better than losses + + return reward +``` + +The key: **Reward functions must be verifiable**. You can verify "is the answer correct?" but not "is this creative?" + +## šŸ”„ Switching to Other Games + +The beauty of OpenEnv: **same code works for any environment!** + +### Try Tic-Tac-Toe +```bash +OPENSPIEL_GAME=tic_tac_toe python -m envs.openspiel_env.server.app --port 8005 +``` +Update config: `server_url = "http://localhost:8005"` + +### Try Chess +```bash +OPENSPIEL_GAME=chess python -m envs.openspiel_env.server.app --port 8006 +``` + +### Try Atari +```bash +python -m envs.atari_env.server.app --game pong --port 8007 +``` + +Everything else stays the same! Same GRPO code, same Forge infrastructure. 
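Under the hood, only the server you point the client at changes. Below is a minimal sketch of the environment-agnostic client loop; the port and `game_name="tic_tac_toe"` are simply the values from the Tic-Tac-Toe example above:

```python
import random

from envs.openspiel_env import OpenSpielAction, OpenSpielEnv

env = OpenSpielEnv(base_url="http://localhost:8005")  # whichever server you started

result = env.reset()
done = False
while not done:
    # Pick any legal action; swap this for your trained policy
    action_id = random.choice(result.observation.legal_actions)
    result = env.step(OpenSpielAction(action_id=action_id, game_name="tic_tac_toe"))
    done = result.done

print(f"Episode reward: {result.reward}")
env.close()
```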
+ +## šŸ› ļø Customization + +All code is in `grpo_utils.py`: +- Modify `BlackJackReward.evaluate_response()` for reward shaping +- Adjust `ComputeAdvantages.compute()` for advantage computation +- Tweak `simple_grpo_loss()` for KL penalty (beta parameter) +- Change `format_prompt()` for different prompt templates + +Edit `blackjack.yaml` for: +- Different model sizes (1B to 70B+) +- More training steps +- Larger group sizes +- Parallel rollout collection + +## šŸ“Š Expected Results + +- **Random policy**: ~35% win rate +- **After GRPO training**: Improves toward optimal BlackJack strategy (~43% win rate) +- **Training time**: Varies based on model size and training steps + +The model learns both strategy AND reasoning process (similar to DeepSeek R1's `` tokens). + +## šŸ¤ Credits + +- **OpenEnv**: Meta PyTorch team +- **Forge**: Meta PyTorch team +- **GRPO**: DeepSeek research team +- **Tutorial inspiration**: Unsloth team + +## šŸ“ License + +This example follows the same license as the parent OpenEnv repository. + +## šŸ™ Acknowledgments + +Big thanks to the **Unsloth team** for their educational approach to RL! This tutorial's GRPO section is heavily inspired by their excellent guide. diff --git a/examples/grpo_blackjack/blackjack.yaml b/examples/grpo_blackjack/blackjack.yaml new file mode 100644 index 00000000..dcf4c690 --- /dev/null +++ b/examples/grpo_blackjack/blackjack.yaml @@ -0,0 +1,155 @@ +# BlackJack GRPO Training Configuration +# >>> python -m apps.grpo.blackjack_main --config apps/grpo/blackjack.yaml +# +# Prerequisites: +# 1. Start BlackJack server: +# cd /Users/sanyambhutani/OpenEnv/OpenEnv +# export PYTHONPATH="/Users/sanyambhutani/OpenEnv/OpenEnv/src:${PYTHONPATH}" +# OPENSPIEL_GAME=blackjack python -m envs.openspiel_env.server.app +# +# 2. 
Run training: +# python -m apps.grpo.blackjack_main --config apps/grpo/blackjack.yaml + +# Global configuration +group_size: 4 # Number of parallel games per rollout +local_batch_size: 8 # Per-device batch size +max_req_tokens: 512 # Max tokens for prompt (BlackJack prompts are ~200-300 tokens) +max_res_tokens: 32 # Max tokens for response (just "HIT" or "STAND" + thinking) +model: "Qwen/Qwen3-1.7B" +off_by_n: 1 # Off-policy tolerance + +# Main loop configuration +rollout_threads: 1 # Number of parallel rollout threads + +# Observability configuration +metric_logging: + wandb: + project: "blackjack-grpo-tutorial" + group: "blackjack_exp_${oc.env:USER}" + reduce_across_ranks: True + console: + reduce_across_ranks: True + +# BlackJack environment configuration +blackjack_env: + server_url: "http://localhost:8004" + model: ${model} + +# Policy configuration (generator) +policy: + engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs + model: ${model} + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + enforce_eager: false + sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams + n: 1 # Generate 1 response per game state (not group_size, since we play full games) + max_tokens: ${max_res_tokens} + temperature: 1.0 + top_p: 1.0 + +# Trainer configuration +trainer: + model: + name: qwen3 + flavor: 1.7B + hf_assets_path: hf://${model} + optimizer: + name: AdamW + lr: 1e-5 + eps: 1e-8 + lr_scheduler: + warmup_steps: 1 + training: + local_batch_size: ${local_batch_size} + seq_len: 1024 # Shorter than GSM8K since BlackJack episodes are shorter + max_norm: 1.0 + steps: 1000 # Tutorial: 1000 steps (increase for production) + dtype: bfloat16 + gc_freq: 1 + compile: + enable: false + parallelism: + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: 1 + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + context_parallel_degree: 1 + expert_parallel_degree: 1 + disable_loss_parallel: true + checkpoint: + enable: true + initial_load_path: hf://${model} + initial_load_in_hf: true + last_save_in_hf: true + interval: 500 + async_mode: "disabled" + activation_checkpoint: + mode: selective + selective_ac_option: op + +# Replay buffer configuration +replay_buffer: + batch_size: ${local_batch_size} + max_policy_age: ${off_by_n} + dp_size: ${trainer.parallelism.data_parallel_shard_degree} + +# Reference model configuration +ref_model: + model: + name: qwen3 + flavor: 1.7B + hf_assets_path: hf://${model} + training: + seq_len: ${trainer.training.seq_len} + dtype: bfloat16 + gc_freq: 1 + compile: + enable: false + parallelism: + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: 1 + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + context_parallel_degree: 1 + expert_parallel_degree: 1 + checkpoint: + enable: true + initial_load_path: hf://${model} + initial_load_in_hf: true + +# All resource allocations +services: + policy: + procs: ${policy.engine_args.tensor_parallel_size} + num_replicas: 1 + mesh_name: policy + with_gpus: true + ref_model: + procs: 1 + num_replicas: 1 + mesh_name: ref_model + with_gpus: true + reward_actor: + procs: 1 + num_replicas: 1 + mesh_name: reward_actor + with_gpus: false + +actors: + blackjack_env: + procs: 1 + with_gpus: false + mesh_name: blackjack_env + trainer: + procs: 1 + with_gpus: true + mesh_name: trainer + replay_buffer: + procs: 1 + with_gpus: false + mesh_name: replay_buffer + compute_advantages: + procs: 1 + 
with_gpus: false + mesh_name: compute_advantages diff --git a/examples/grpo_blackjack/grpo_blackjack_tutorial.ipynb b/examples/grpo_blackjack/grpo_blackjack_tutorial.ipynb new file mode 100644 index 00000000..ddd0f9ea --- /dev/null +++ b/examples/grpo_blackjack/grpo_blackjack_tutorial.ipynb @@ -0,0 +1,497 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training LLMs in ANY Environment with OpenEnv\n", + "\n", + "## \ud83c\udfaf The Vision\n", + "\n", + "Imagine training language models in:\n", + "- \ud83c\udfb0 **Card games** (BlackJack, Poker, Uno)\n", + "- \u265f\ufe0f **Board games** (Chess, Go, Connect Four)\n", + "- \ud83d\udcc8 **Trading simulations** (realistic market environments)\n", + "- \ud83c\udfae **Atari games** (Pong, Breakout, Space Invaders)\n", + "- \ud83d\udcbb **Code execution environments** (interactive debugging)\n", + "- \ud83e\udd16 **Robotics simulations** (MuJoCo, PyBullet)\n", + "\n", + "---\n", + "\n", + "### The Problem\n", + "\n", + "Every RL environment has different APIs:\n", + "- \u274c OpenSpiel uses C++ bindings\n", + "- \u274c Atari needs ALE (Arcade Learning Environment)\n", + "- \u274c Trading sims have custom interfaces\n", + "- \u274c Each requires different dependencies, versions, OS compatibility\n", + "- \u274c No isolation \u2192 crashes can corrupt your system\n", + "\n", + "**You spend more time wrestling with environments than training models.**\n", + "\n", + "---\n", + "\n", + "### The Solution: OpenEnv - A Universal Spec\n", + "\n", + "
#### \ud83d\ude80 OpenEnv = Universal RL Environment Interface\n", + "\n", + "**OpenEnv is not a game engine.** It's a specification that wraps ANY RL environment with a clean, unified API.\n", + "\n", + "*One interface. Any environment. Zero setup.*
\n", + "\n", + "---\n", + "\n", + "## What You'll Build\n", + "\n", + "In this tutorial, you'll:\n", + "1. \ud83d\udd0c **Explore OpenEnv** - Connect to BlackJack, see how the spec works\n", + "2. \ud83c\udfb2 **Benchmark policies** - Test random vs heuristic strategies\n", + "3. \ud83e\udde0 **Learn about GRPO** - Brief intro to the training algorithm\n", + "4. \u26a1 **Train with Forge** - Use PyTorch's agentic RL library\n", + "5. \ud83d\udcca **Compare results** - Measure improvement\n", + "6. \ud83d\udd04 **Switch environments** - Show how to train on different games\n", + "\n", + "**This uses production code.** Same implementation as `apps/grpo/blackjack_main_fixed.py`.\n", + "\n", + "---\n", + "\n", + "### \ud83d\udcda Resources\n", + "- \ud83d\udce6 [OpenEnv GitHub](https://github.com/meta-pytorch/OpenEnv) - Universal RL environment spec\n", + "- \ud83d\udcc4 [GRPO Paper (arXiv:2402.03300)](https://arxiv.org/abs/2402.03300) - Group Relative Policy Optimization\n", + "- \ud83d\udd27 [Forge GitHub](https://github.com/meta-pytorch/torchforge) - PyTorch-native agentic RL library\n", + "- \ud83d\udcd6 [Forge Docs](https://meta-pytorch.org/torchforge/) - Full documentation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## \ud83d\udd0c Part 1: Exploring OpenEnv\n\nLet's connect to a BlackJack environment and explore the OpenEnv spec.\n\n### Start the Server\n\n
\u26a0\ufe0f **Note:** Start the OpenEnv server in a separate terminal:\n\n```bash\n# Set your OpenEnv path\nexport OPENENV_PATH=\"/path/to/OpenEnv/src\"\nexport PYTHONPATH=\"${OPENENV_PATH}:${PYTHONPATH}\"\n\n# Start BlackJack server\nOPENSPIEL_GAME=blackjack python -m envs.openspiel_env.server.app --port 8004\n```\n
" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Environment setup for Jupyter\n", + "import sys\n", + "import os\n", + "\n", + "# Fix for Monarch/Torchstore Rust bindings in Jupyter\n", + "conda_prefix = os.environ.get('CONDA_PREFIX', sys.prefix)\n", + "lib_path = f\"{conda_prefix}/lib\"\n", + "\n", + "if 'LD_LIBRARY_PATH' in os.environ:\n", + " if lib_path not in os.environ['LD_LIBRARY_PATH']:\n", + " os.environ['LD_LIBRARY_PATH'] = f\"{lib_path}:{os.environ['LD_LIBRARY_PATH']}\"\n", + "else:\n", + " os.environ['LD_LIBRARY_PATH'] = lib_path\n", + "\n", + "print(\"\u2705 Environment configured\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Connect to OpenEnv\n", + "\n", + "Let's connect to the BlackJack environment and explore its interface." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys", + "import os", + "from pathlib import Path", + "", + "# Add OpenEnv to path (update this to your OpenEnv installation)", + "openenv_path = os.environ.get('OPENENV_PATH', '/path/to/OpenEnv/src')", + "if openenv_path not in sys.path:", + " sys.path.insert(0, openenv_path)", + "", + "from envs.openspiel_env import OpenSpielEnv, OpenSpielAction", + "from grpo_utils import show_openenv_observation", + "", + "# Connect to environment", + "env = OpenSpielEnv(base_url=\"http://localhost:8004\")", + "", + "print(\"\ud83c\udfb0 Connected to BlackJack environment\")", + "print(\"\\n\ud83d\udccd Resetting environment...\\n\")", + "", + "# Reset and observe", + "result = env.reset()", + "show_openenv_observation(result.observation)", + "", + "env.close()", + "print(\"\\n\u2705 OpenEnv interface exploration complete!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### What Just Happened?\n", + "\n", + "You just saw the **OpenEnv spec** in action:\n", + "\n", + "```python\n", + "# Universal interface - works for ANY environment\n", + "result = env.reset() # Start episode\n", + "result = env.step(action) # Take action\n", + "state = env.state() # Get environment state\n", + "env.close() # Cleanup\n", + "```\n", + "\n", + "**Key observations:**\n", + "- `legal_actions`: What actions the agent can take\n", + "- `info_state`: Numeric observation vector\n", + "- `game_phase`: Current phase of the game\n", + "- `reward`: Outcome (+1 win, -1 loss, 0 push)\n", + "\n", + "This same interface works for **70+ different environments**. Change the server, everything else stays the same!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## \ud83c\udfb2 Part 2: Benchmarking Baseline Policies\n", + "\n", + "Before training an LLM, let's see how simple policies perform." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from grpo_utils import play_random_policy", + "", + "print(\"\ud83c\udfb2 Running random policy baseline...\\n\")", + "", + "# Play 100 games with random actions", + "stats = play_random_policy(\"http://localhost:8004\", num_games=100)", + "", + "print(\"\\n\ud83d\udcca Random Policy Results:\")", + "print(f\" Games played: {stats['total_games']}\")", + "print(f\" Wins: {stats['wins']}\")", + "print(f\" Losses: {stats['losses']}\")", + "print(f\" Pushes: {stats['pushes']}\")", + "print(f\" Win rate: {stats['win_rate']:.1%}\")", + "print(\"\\n\ud83d\udcdd Note: Optimal BlackJack strategy achieves ~43% win rate\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### The Challenge\n", + "\n", + "Random policy performs poorly (~30-35% win rate).\n", + "\n", + "**Can we train an LLM to do better?**\n", + "\n", + "That's where **GRPO** comes in." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## \ud83e\udde0 Part 3: Understanding Reinforcement Learning & GRPO\n\n
### \ud83d\udcda Section Inspired by Unsloth\n\nThis section is heavily inspired by the excellent [Unsloth RL Guide](https://docs.unsloth.ai/get-started/reinforcement-learning-rl-guide). Unsloth has done an amazing job making RL accessible and intuitive. We highly recommend reading their full guide for deeper insights and practical tips!\n\n\ud83d\ude4f Big thanks to the Unsloth team for their educational approach to RL.\n\n---\n\n### What is Reinforcement Learning?\n\n#### The Core Idea (It's Simpler Than You Think!)\n\nThe goal of RL is extremely simple: **increase the probability of \"good\" outcomes and decrease the probability of \"bad\" outcomes.**\n\nThat's it! Everything else is just details about what \"good\" and \"bad\" mean, and how to increase/decrease their probabilities.
\n\n#### A Simple Example: Learning \"2 + 2 = ?\"\n\nImagine an untrained language model trying to answer \"What is 2+2?\". It might output:\n\n```\n0, cat, -10, 1928, 3, A, B, 122, 17, 182, 172, A, C, BAHS, %$, #, 9, -192, 12.31, ...\n```\n\nThen suddenly: **4** \u2713\n\nThe reward signals would be:\n```\n0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... then 1\n```\n\n**This is the key insight:** By patience (or \"luck\"), if the correct answer has *any* non-zero probability, RL will eventually find it. The trick is:\n1. While waiting, we learn from **bad answers** \u2192 tell model \"don't do this\"\n2. When we find **good answers** \u2192 tell model \"do more of this\"\n\nThis is why I like to call it **\"Patience Is All You Need\"** for RL.\n\n---\n\n### From PPO to GRPO: The Evolution\n\n
#### \ud83d\udcdc The Algorithm Evolution\n\n| Approach | Models required |\n| --- | --- |\n| RLHF + PPO (OpenAI ChatGPT) | Needed 3 models: Policy, Reference, Value Model |\n| GRPO (DeepSeek R1) | Only needs 2 models: Policy + Reference \u2192 Much more efficient! |
\n\n**What GRPO removes:**\n- \u274c **Value Model** \u2192 Replaced with group statistics\n- \u274c **Reward Model** \u2192 Replaced with simple reward functions\n\n**Why this matters:**\n- \ud83d\udcbe Less memory usage\n- \u26a1 Faster training\n- \ud83c\udfaf Easier to implement\n\n---\n\n### GRPO: Group Relative Policy Optimization\n\n
#### Why \"Group Relative\"?\n\nInstead of training a separate Value Model to estimate \"how good is this state?\", GRPO uses a clever trick: sample the model multiple times and compare answers within the group.
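\n\nA tiny sketch of that trick, mirroring `ComputeAdvantages.compute()` in `grpo_utils.py` (the example rewards are made up):\n\n```python\nimport torch\n\n# Rewards for 4 sampled answers to the same prompt\nrewards = torch.tensor([1.0, 0.0, -1.0, -1.0])\n\n# Group-relative advantage = z-score within the group\nadvantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)\n```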
\n\n**Example: Training on \"What is 2+2?\"**\n\n1. **Generate multiple responses** (e.g., 4 samples):\n - Response 1: \"4\" \u2192 reward = +1 (correct!)\n - Response 2: \"3\" \u2192 reward = 0 (close, but wrong)\n - Response 3: \"D\" \u2192 reward = -1 (nonsense)\n - Response 4: \"C\" \u2192 reward = -1 (nonsense)\n\n2. **Calculate group statistics:**\n - Mean reward: (-1 + -1 + 0 + 1) / 4 = -0.25\n - Standard deviation: ~0.83\n\n3. **Compute advantages** (Z-score normalization):\n - Response 1: +1.5 (much better than average!)\n - Response 2: +0.3 (slightly better)\n - Response 3: -0.9 (worse than average)\n - Response 4: -0.9 (worse than average)\n\n4. **Update model:**\n - Increase probability of generating \"4\"\n - Slightly increase \"3\" (it's closer than nonsense)\n - Decrease probability of generating \"D\" and \"C\"\n\nThis is **group-relative** because we're comparing within the group, not to an absolute baseline!\n\n---\n\n### Reward Functions: The Secret Sauce\n\nReward functions tell the model what's \"good\" and what's \"bad\". They can be simple or complex:\n\n**For BlackJack (what we're using):**\n```python\ndef evaluate_response(prompt, response, game_reward):\n reward = float(game_reward) # +1 (win), -1 (loss), 0 (push)\n \n # Reward shaping: Scale up wins\n if game_reward > 0:\n reward = 2.0 # Wins are more valuable\n elif game_reward == 0:\n reward = 0.5 # Pushes better than losses\n \n return reward\n```\n\n**For Math Problems:**\n- If answer is a number: +1\n- If answer matches ground truth: +3\n- If no number detected: -1\n- **Total reward:** Sum of all criteria\n\n**For Email Automation:**\n- Contains required keyword: +1\n- Matches ideal response: +1\n- Too long: -1\n- Includes recipient name: +1\n- Has signature block: +1\n\nThe key is: **Reward functions must be verifiable**. You can't subjectively judge \"is this creative?\" but you can verify \"is this answer correct?\"\n\n---\n\n### The Training Process (Simplified)\n\n```\n1. Play game \u2192 Get action \"HIT\" or \"STAND\"\n \u2193\n2. Game ends \u2192 Observe reward (+1 win, -1 loss, 0 push)\n \u2193\n3. Repeat 4-8 times for the same question (group)\n \u2193\n4. Calculate group statistics (mean, std)\n \u2193\n5. Compute advantages (which answers were better/worse than average?)\n \u2193\n6. Update model: increase good action probability, decrease bad\n \u2193\n7. Repeat thousands of times \u2192 Model learns strategy!\n```\n\n**Key insight:** Over time, the model learns not just \"what to do\" but also *why* (the reasoning process). This is how DeepSeek R1 developed its famous `` tokens!\n\n---\n\n### Forge: PyTorch-Native Agentic RL Infrastructure\n\n
#### What is Forge?\n\nForge is PyTorch's official library for training agentic RL models. It handles all the distributed systems complexity so you can focus on algorithms.\n\n- **Generator (vLLM):** Fast LLM inference with automatic batching\n- **RLTrainer:** Distributed training with FSDP across GPUs\n- **ReplayBuffer:** Stores episodes for off-policy learning\n- **ReferenceModel:** Keeps original model for KL penalty\n- **Torchstore:** Distributed weight management across replicas
\n\n**Resources:**\n- \ud83d\udd27 [GitHub](https://github.com/meta-pytorch/torchforge) - Source code\n- \ud83d\udcd6 [Documentation](https://meta-pytorch.org/torchforge/) - Full docs\n- \ud83d\udcc4 [GRPO Paper](https://arxiv.org/abs/2402.03300) - Original research\n\n**In this tutorial:** We abstract all of Forge's complexity. You just call:\n```python\ntrainer = await setup_forge_training(\"config.yaml\")\nawait trainer.run(steps=100)\n```\n\nEverything else happens automatically! \ud83d\ude80" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## \ud83c\udfd7\ufe0f Part 4: Training with GRPO\n", + "\n", + "Now let's train a Qwen 1.5B model to play BlackJack using production GRPO code.\n", + "\n", + "### Architecture Overview\n", + "\n", + "```\n", + "\u250f\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2513\n", + "\u2503 YOUR TRAINING LOOP \u2503\n", + "\u2523\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u252b\n", + "\u2503 \u2503\n", + "\u2503 Rollouts Loop Training Loop \u2503\n", + "\u2503 \u2022 Play games \u2022 Sample batch \u2503\n", + "\u2503 \u2022 Collect episodes \u2022 Compute loss \u2503\n", + "\u2503 \u2022 Compute advantages \u2022 Update weights \u2503\n", + "\u2503 \u2022 Add to buffer \u2022 Push to replicas \u2503\n", + "\u2503 \u2503\n", + "\u2517\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u252f\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u252f\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u251b\n", + " \u2502 \u2502\n", + " HTTP \u2502 \u2502 RPC\n", + " \u2502 \u2502\n", + " \u2193 \u2193\n", + " \u250f\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2513 \u250f\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2513\n", + " \u2503 OpenEnv \u2503 \u2503 Forge \u2503\n", + " \u2503 Server \u2503 \u2503 Services \u2503\n", + " \u2517\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u251b \u2517\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u251b\n", + "```\n", + "\n", + "**Two concurrent loops:**\n", + "1. **Rollouts:** Play games via OpenEnv \u2192 collect episodes\n", + "2. **Training:** Sample from buffer \u2192 update policy with GRPO\n", + "\n", + "They run in parallel for maximum efficiency!" 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setup and Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from grpo_utils import setup_forge_trainingprint(\"\ud83c\udfd7\ufe0f Initializing Forge infrastructure...\\n\")print(\"This will:\")print(\" \u2022 Load the Qwen 1.5B model\")print(\" \u2022 Initialize vLLM inference servers\")print(\" \u2022 Setup distributed training (TorchTitan)\")print(\" \u2022 Create replay buffer and reference model\")print(\"\\n\u23f3 This may take 1-2 minutes...\\n\")# Initialize everything with one function calltrainer = await setup_forge_training(\"blackjack.yaml\")print(\"\\n\u2705 Ready to train!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run Training\n", + "\n", + "Now we train for 100 steps. This is a shortened demo - production training uses 1000+ steps." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\ud83d\ude80 Starting GRPO training!\\n\")\n", + "print(\"Watch the logs to see:\")\n", + "print(\" \u2022 Games being played (with actions and outcomes)\")\n", + "print(\" \u2022 Win rate improving over time\")\n", + "print(\" \u2022 Training steps updating the policy\")\n", + "print(\"\\n\" + \"=\"*60 + \"\\n\")\n", + "\n", + "# Run training (this is the production training loop!)\n", + "results = await trainer.run(steps=100)\n", + "\n", + "print(\"\\n\" + \"=\"*60)\n", + "print(\"\\n\ud83c\udf89 Training complete!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Cleanup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Shutdown Forge services\n", + "await trainer.shutdown()\n", + "print(\"\u2705 Shutdown complete\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## \ud83d\udd04 Part 5: The Power of OpenEnv - Switching Environments\n", + "\n", + "Here's the magic: **The same code works for ANY OpenEnv environment.**\n", + "\n", + "### Switch to Tic-Tac-Toe\n", + "\n", + "Just change the server:\n", + "\n", + "```bash\n", + "# Terminal:\n", + "OPENSPIEL_GAME=tic_tac_toe python -m envs.openspiel_env.server.app --port 8005\n", + "```\n", + "\n", + "Update config:\n", + "```python\n", + "cfg.blackjack_env.server_url = \"http://localhost:8005\"\n", + "```\n", + "\n", + "**Everything else stays identical.** Same GRPO code, same Forge infrastructure.\n", + "\n", + "---\n", + "\n", + "### Switch to Chess\n", + "\n", + "```bash\n", + "OPENSPIEL_GAME=chess python -m envs.openspiel_env.server.app --port 8006\n", + "```\n", + "\n", + "Update model and config for longer sequences, done!\n", + "\n", + "---\n", + "\n", + "### Switch to Atari\n", + "\n", + "```bash\n", + "# Different OpenEnv backend\n", + "python -m envs.atari_env.server.app --game pong --port 8007\n", + "```\n", + "\n", + "Modify prompt formatting for vision inputs, same training loop!\n", + "\n", + "---\n", + "\n", + "
### \ud83d\udca1 The Key Insight\n", + "\n", + "**OpenEnv is a spec, not a game engine.** Once you have a training loop that talks to OpenEnv, you can train on ANY environment that implements the spec.\n", + "\n", + "Change one environment variable \u2192 train on 70+ different environments.
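\n", + "\n", + "For example, re-pointing the tutorial's config at a different server is a one-line change (a sketch; port 8005 is just the Tic-Tac-Toe server started above):\n", + "\n", + "```python\n", + "from omegaconf import OmegaConf\n", + "\n", + "cfg = OmegaConf.load(\"blackjack.yaml\")\n", + "cfg.blackjack_env.server_url = \"http://localhost:8005\"  # Tic-Tac-Toe server\n", + "```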
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## \ud83d\ude80 Next Steps\n", + "\n", + "### 1. Scale Up Training\n", + "\n", + "Edit `apps/grpo/blackjack.yaml`:\n", + "\n", + "```yaml\n", + "trainer:\n", + " training:\n", + " steps: 1000 # More training steps\n", + "\n", + "group_size: 8 # More games per rollout\n", + "rollout_threads: 4 # Parallel rollout collection\n", + "```\n", + "\n", + "Run from command line for serious training:\n", + "\n", + "```bash\n", + "python -m apps.grpo.blackjack_main_fixed --config apps/grpo/blackjack.yaml\n", + "```\n", + "\n", + "---\n", + "\n", + "### 2. Explore Other Environments\n", + "\n", + "Try different OpenSpiel games:\n", + "- `OPENSPIEL_GAME=tic_tac_toe`\n", + "- `OPENSPIEL_GAME=connect_four`\n", + "- `OPENSPIEL_GAME=go`\n", + "\n", + "Explore other OpenEnv backends:\n", + "- Atari environments\n", + "- FinRL trading simulations\n", + "- Custom environments\n", + "\n", + "---\n", + "\n", + "### 3. Customize the Training\n", + "\n", + "All the code is in `apps/grpo/grpo_utils.py`:\n", + "- Modify reward shaping in `BlackJackReward.evaluate_response()`\n", + "- Adjust advantage computation in `ComputeAdvantages.compute()`\n", + "- Tweak GRPO loss hyperparameters (beta, KL penalty)\n", + "- Change prompt formatting in `format_prompt()`\n", + "\n", + "---\n", + "\n", + "## \ud83d\udcda Resources\n", + "\n", + "### OpenEnv\n", + "- \ud83d\udce6 [GitHub](https://github.com/meta-pytorch/OpenEnv) - Source code and examples\n", + "- \ud83d\udcd6 [Spec Documentation](https://github.com/meta-pytorch/OpenEnv#spec) - Full API reference\n", + "\n", + "### GRPO\n", + "- \ud83d\udcc4 [Paper (arXiv:2402.03300)](https://arxiv.org/abs/2402.03300) - Original publication\n", + "- \ud83d\udd2c [Blog Post](https://ai.meta.com/blog/grpo/) - High-level explanation\n", + "\n", + "### Forge\n", + "- \ud83d\udd27 [GitHub](https://github.com/meta-pytorch/torchforge) - PyTorch-native agentic RL\n", + "- \ud83d\udcd6 [Docs](https://meta-pytorch.org/torchforge/) - Full documentation\n", + "- \ud83d\udcac [Discussions](https://github.com/meta-pytorch/torchforge/discussions) - Community support\n", + "\n", + "---\n", + "\n", + "## \ud83c\udf93 Key Takeaways\n", + "\n", + "
### What You Learned\n", + "\n", + "1. **OpenEnv is a universal spec for RL environments** - not just games, ANY interactive environment.\n", + "2. **One training loop works everywhere** - switch environments by changing a URL.\n", + "3. **Forge abstracts distributed RL complexity** - focus on algorithms, not infrastructure.\n", + "4. **GRPO enables stable LLM training** - group-relative advantages + KL penalties work.\n", + "5. **Production code is accessible** - this notebook uses the same code as large-scale training.
\n", + "\n", + "---\n", + "\n", + "
## \ud83c\udf89 Congratulations!\n", + "\n", + "You just trained an LLM using production GRPO code. You explored OpenEnv as a universal RL interface. And you saw how Forge abstracts away distributed training complexity.\n", + "\n", + "**Now go train agents in ANY environment! \ud83d\ude80**
" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/examples/grpo_blackjack/grpo_utils.py b/examples/grpo_blackjack/grpo_utils.py new file mode 100644 index 00000000..35ee4649 --- /dev/null +++ b/examples/grpo_blackjack/grpo_utils.py @@ -0,0 +1,852 @@ +""" +GRPO Utilities for OpenEnv Training + +This module contains reusable components extracted from the production GRPO implementation. +Used by both the tutorial notebook and the full training script. +""" + +import asyncio +import time +import uuid +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Any + +import torch +import torch.nn.functional as F +import torchstore as ts +from omegaConf import DictConfig + +from envs.openspiel_env import OpenSpielAction, OpenSpielEnv +from forge.actors._torchstore_utils import ( + get_dcp_whole_state_dict_key, + get_param_prefix, +) +from forge.actors.generator import Generator +from forge.actors.reference_model import ReferenceModel +from forge.actors.replay_buffer import ReplayBuffer +from forge.actors.trainer import RLTrainer +from forge.controller.actor import ForgeActor +from forge.controller.provisioner import init_provisioner, shutdown +from forge.data_models.completion import Completion +from forge.observability.metric_actors import get_or_create_metric_logger +from forge.observability.metrics import Reduce, record_metric +from forge.observability.perf_tracker import Tracer +from forge.types import LauncherConfig, ProvisionerConfig +from forge.util.ops import compute_logprobs +from monarch.actor import endpoint +from vllm.transformers_utils.tokenizer import get_tokenizer + + +# ============================================================================ +# Data Structures +# ============================================================================ + + +@dataclass +class Episode: + """Episode data for RL training.""" + + episode_id: str + pad_id: int + request_len: int + response_len: int + game_id: str + step_in_game: int + completion: Completion | None = None + ref_logprobs: torch.Tensor | None = None + reward: float | None = None + advantage: float | None = None + + @property + def policy_version(self) -> int | None: + return self.completion.generator_version + + @property + def request_tensor(self) -> torch.Tensor: + request_tokens: torch.Tensor = self.completion.prompt_ids + tensor = torch.tensor(request_tokens, dtype=torch.long) + if tensor.shape[0] < self.request_len: + diff = self.request_len - tensor.shape[0] + tensor = F.pad(tensor, (diff, 0), value=self.pad_id) + return tensor + + @property + def response_tensor(self) -> torch.Tensor: + response_tokens: torch.Tensor = self.completion.token_ids + tensor = torch.tensor(response_tokens, dtype=torch.long) + if tensor.shape[0] < self.response_len: + diff = self.response_len - tensor.shape[0] + tensor = F.pad(tensor, (0, diff), value=self.pad_id) + return tensor + + +Group = list[Episode] + + +# ============================================================================ +# GRPO Loss and Collation +# 
============================================================================ + + +def collate(batches: list[Group]) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + """ + Collate batches of episodes into model inputs and targets. + + Args: + batches: List of episode groups + + Returns: + Tuple of (inputs, targets) for training + """ + inputs = [] + targets = [] + for batch in batches: + request = torch.stack([e.request_tensor for e in batch]) + response = torch.stack([e.response_tensor for e in batch]) + ref_logprobs = torch.stack([e.ref_logprobs for e in batch]).squeeze() + advantages = torch.tensor([e.advantage for e in batch]).unsqueeze(-1) + pad_id = batch[0].pad_id + mask = response != pad_id + + input = {"tokens": torch.cat([request, response], dim=1)} + target = { + "response": response, + "ref_logprobs": ref_logprobs, + "advantages": advantages, + "padding_mask": mask, + } + inputs.append(input) + targets.append(target) + return inputs, targets + + +def simple_grpo_loss( + logits: torch.Tensor, + response: torch.Tensor, + ref_logprobs: torch.Tensor, + advantages: torch.Tensor, + padding_mask: torch.Tensor, + beta: float = 0.1, +) -> torch.Tensor: + """ + GRPO loss with KL penalty. + + Args: + logits: Model logits + response: Response tokens + ref_logprobs: Reference model log probabilities + advantages: Normalized advantages (group-relative) + padding_mask: Mask for padded tokens + beta: KL penalty coefficient + + Returns: + Scalar loss value + """ + logprobs: torch.Tensor = compute_logprobs(logits, response) + + # KL divergence: KL(ref || policy) in closed form + kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1 + + # Policy gradient term with importance weight + per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages + + # Combined loss: maximize policy improvement, minimize KL + per_token_loss = -(per_token_policy_loss - beta * kl) + + # Average over valid tokens + loss = ( + ((per_token_loss * padding_mask).sum(dim=1)) + / (padding_mask.sum(dim=1).clamp(min=1.0)) + ).mean() + + return loss + + +# ============================================================================ +# Prompt Formatting and Action Parsing +# ============================================================================ + + +def format_prompt(step_num: int, action_history: list, tokenizer) -> str: + """ + Format game state as text prompt for LLM. + + Args: + step_num: Current step number in game + action_history: List of (action_id, action_name) tuples + tokenizer: HuggingFace tokenizer with chat template + + Returns: + Formatted prompt string + """ + system = """You are an expert BlackJack player. Output only 'HIT' or 'STAND'.""" + + state_desc = f"=== BlackJack Game (Step {step_num + 1}) ===\n\n" + if action_history: + state_desc += "Previous actions:\n" + for i, (_, name) in enumerate(action_history): + state_desc += f" {i + 1}. {name}\n" + state_desc += "\n" + + state_desc += "What do you do? (Output only 'HIT' or 'STAND')" + + chat = [ + {"role": "system", "content": system}, + {"role": "user", "content": state_desc}, + ] + + return tokenizer.apply_chat_template( + chat, tokenize=False, add_generation_prompt=True + ) + + +def parse_action(response_text: str, legal_actions: list[int]) -> int: + """ + Parse action from model's text response. 
+ + Args: + response_text: Model's generated text + legal_actions: List of legal action IDs + + Returns: + Action ID (0=HIT, 1=STAND) + """ + text_lower = response_text.lower().strip() + + if "hit" in text_lower: + action_id = 0 + elif "stand" in text_lower: + action_id = 1 + else: + action_id = 1 # Default: STAND + + # Ensure action is legal + if action_id not in legal_actions: + action_id = legal_actions[0] + + return action_id + + +# ============================================================================ +# Forge Actors +# ============================================================================ + + +@dataclass +class BlackJackReward(ForgeActor): + """Reward actor for evaluating game outcomes.""" + + @endpoint + async def evaluate_response( + self, prompt: str, response: str, game_reward: float + ) -> float: + """ + Evaluate episode reward with optional shaping. + + Args: + prompt: Game state prompt + response: Model's action + game_reward: Raw game outcome (+1/-1/0) + + Returns: + Shaped reward value + """ + # Base reward from game outcome + reward = float(game_reward) + + # Optional reward shaping: Scale up wins + if game_reward > 0: + reward = 2.0 # Make wins more valuable + elif game_reward == 0: + reward = 0.5 # Pushes better than losses + + record_metric("reward/evaluate_response/avg_reward", reward, Reduce.MEAN) + record_metric("reward/evaluate_response/sum_reward", reward, Reduce.SUM) + + return reward + + +@dataclass +class ComputeAdvantages(ForgeActor): + """Actor for computing group-relative advantages.""" + + @endpoint + async def compute(self, group: Group) -> list[float]: + """ + Compute advantages normalized by group statistics. + + Args: + group: List of episodes from same rollout + + Returns: + List of advantage values + """ + rewards = torch.tensor([[e.reward for e in group]]) + mean = rewards.mean(1, keepdim=True) + std = rewards.std(1, keepdim=True) + advantages = (rewards - mean) / (std + 1e-4) + return advantages.squeeze(0).tolist() + + +@dataclass +class EnvironmentActor(ForgeActor): + """Actor that manages OpenEnv connections and tokenizer.""" + + server_url: str = "http://localhost:8004" + model: str = "Qwen/Qwen2.5-1.5B-Instruct" + + @endpoint + def setup(self): + """Initialize tokenizer.""" + self._tokenizer = get_tokenizer(self.model) + print(f"EnvironmentActor initialized (server: {self.server_url})") + + @endpoint + async def get_tokenizer(self): + """Get tokenizer instance.""" + return self._tokenizer + + @endpoint + async def pad_token(self): + """Get padding token ID.""" + return self._tokenizer.pad_token_id + + +# Alias for backwards compatibility +BlackJackEnvActor = EnvironmentActor + + +# ============================================================================ +# Logging and Utilities +# ============================================================================ + + +def setup_game_logger(log_dir: str = "game_logs"): + """ + Setup detailed game logging to file. 
+ + Args: + log_dir: Directory for log files + + Returns: + Logging function + """ + Path(log_dir).mkdir(exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + log_file = Path(log_dir) / f"games_{timestamp}.log" + + def log(message: str): + """Write message to log file and console.""" + with open(log_file, "a") as f: + f.write(f"{message}\n") + print(message) + + log("=" * 80) + log(f"GRPO Training - Game Log Started at {datetime.now()}") + log("=" * 80) + log("") + + return log + + +async def drop_weights(version: int): + """ + Drop old model weights from torchstore. + + Args: + version: Weight version to drop + """ + print(f"Dropping weights @ version {version}") + start_time = time.perf_counter() + + prefix = get_param_prefix(version) + matching_keys = await ts.keys(prefix) + dcp_key = get_dcp_whole_state_dict_key(version) + + if dcp_key in matching_keys: + dcp_handle = await ts.get(dcp_key) + dcp_handle.drop() + + for key in matching_keys: + await ts.delete(key) + + elapsed = time.perf_counter() - start_time + print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds") + + +# ============================================================================ +# Game Playing Logic +# ============================================================================ + + +async def play_game( + game_idx: int, + game_id: str, + server_url: str, + policy, + tokenizer, + game_log, + rollout_count: int = 0 +): + """ + Play a single game and collect episode data. + + Args: + game_idx: Index of this game in the rollout + game_id: Unique game identifier + server_url: OpenEnv server URL + policy: Policy (Generator) for action selection + tokenizer: Tokenizer for prompt formatting + game_log: Logging function + rollout_count: Current rollout iteration + + Returns: + List of step results with prompts, responses, and final reward + """ + env = OpenSpielEnv(base_url=server_url) + + game_log("") + game_log("=" * 80) + game_log(f"šŸŽ® GAME {game_idx + 1} (Rollout #{rollout_count + 1}) - ID: {game_id}") + game_log("=" * 80) + + try: + result = env.reset() + obs = result.observation + done = False + step_num = 0 + action_history = [] + game_steps = [] + + while not done and step_num < 10: # Max 10 steps per game + # Format prompt + prompt = format_prompt(step_num, action_history, tokenizer) + + game_log(f"\n--- Step {step_num + 1} ---") + game_log(f"Legal actions: {obs.legal_actions}") + game_log(f"\nPrompt sent to model:") + game_log("-" * 40) + game_log(prompt) + game_log("-" * 40) + + # Generate action with policy + responses = await policy.generate.route(prompt) + response = responses[0] + + game_log(f"\nšŸ¤– Model response: '{response.text}'") + + # Parse and execute action + action_id = parse_action(response.text, obs.legal_actions) + action_name = "HIT" if action_id == 0 else "STAND" + action_history.append((action_id, action_name)) + + game_log(f"āž”ļø Parsed action: {action_name} (action_id={action_id})") + + # Store step data (reward assigned later) + game_steps.append({ + "step_num": step_num, + "prompt": prompt, + "response": response, + }) + + # Take action in environment + result = env.step(OpenSpielAction(action_id=action_id, game_name="blackjack")) + obs = result.observation + done = result.done + + if done: + game_log(f"šŸ Game ended!") + + step_num += 1 + + # Get final game outcome + final_game_reward = result.reward # +1 (win), -1 (loss), or 0 (push) + + outcome_emoji = "šŸ†" if final_game_reward > 0 else ("šŸ’€" if final_game_reward < 0 else "šŸ¤") + 
outcome_text = "WIN" if final_game_reward > 0 else ("LOSS" if final_game_reward < 0 else "PUSH") + + game_log("") + game_log(f"{outcome_emoji} FINAL OUTCOME: {outcome_text} (reward={final_game_reward})") + game_log(f"šŸ“Š Game length: {len(game_steps)} steps") + game_log(f"šŸŽ² Action sequence: {' → '.join([name for _, name in action_history])}") + + # Assign final reward to all steps + all_step_results = [] + for step_data in game_steps: + all_step_results.append({ + "game_id": game_id, + "final_reward": final_game_reward, + **step_data, + }) + + # Record metrics + record_metric("game/count_games_played", 1, Reduce.SUM) + record_metric("game/avg_game_length", len(game_steps), Reduce.MEAN) + record_metric("game/outcome", final_game_reward, Reduce.MEAN) + + return all_step_results + + finally: + env.close() + + +# Alias for backwards compatibility +play_blackjack_game = play_game + + +# ============================================================================ +# OpenEnv Helper Functions (For Tutorial/Exploration) +# ============================================================================ + + +def show_openenv_observation(observation): + """ + Pretty print an OpenEnv observation. + + Args: + observation: OpenEnv observation object + """ + print("šŸ“Š Observation:") + print(f" Game phase: {observation.game_phase}") + print(f" Legal actions: {observation.legal_actions}") + print(f" Info state shape: {len(observation.info_state)}") + print(f" Info state (first 10): {observation.info_state[:10]}") + + +def play_random_policy(server_url: str, num_games: int = 100): + """ + Benchmark random policy on OpenEnv environment. + + Args: + server_url: OpenEnv server URL + num_games: Number of games to play + + Returns: + dict with statistics + """ + import random + + env = OpenSpielEnv(base_url=server_url) + wins = losses = pushes = 0 + + for _ in range(num_games): + result = env.reset() + done = False + step_count = 0 + + while not done and step_count < 10: + # Random action + action = random.choice(result.observation.legal_actions) + result = env.step(OpenSpielAction(action_id=action, game_name="blackjack")) + done = result.done + step_count += 1 + + # Count outcome + if result.reward > 0: + wins += 1 + elif result.reward < 0: + losses += 1 + else: + pushes += 1 + + env.close() + + return { + "wins": wins, + "losses": losses, + "pushes": pushes, + "win_rate": wins / num_games, + "total_games": num_games + } + + +def play_heuristic_policy(server_url: str, num_games: int = 100): + """ + Benchmark basic strategy heuristic on OpenEnv environment. + + Simple heuristic: HIT if < 17, STAND otherwise + + Args: + server_url: OpenEnv server URL + num_games: Number of games to play + + Returns: + dict with statistics + """ + # This is a simplified heuristic - real basic strategy is more complex + # For tutorial purposes only + return play_random_policy(server_url, num_games) # Placeholder + + +# ============================================================================ +# Forge Training Abstraction (Hides Complexity) +# ============================================================================ + + +class GRPOTrainer: + """ + Simplified interface for GRPO training that hides Forge complexity. + + This class wraps all Forge infrastructure (provisioner, services, actors) + and exposes a clean interface for the tutorial notebook. + """ + + def __init__(self, services: dict, cfg: DictConfig): + """ + Initialize trainer (called by setup_forge_training). 
+ + Args: + services: Dict of initialized Forge services/actors + cfg: Training configuration + """ + self._services = services + self._cfg = cfg + self._metrics = [] + self._shutdown_event = asyncio.Event() + + @property + def policy(self): + """Access the trained policy for playing games.""" + return self._services['policy'] + + async def run(self, steps: int) -> dict: + """ + Run GRPO training for specified steps. + + Args: + steps: Number of training steps + + Returns: + Training metrics dict + """ + # Unpack services + policy = self._services['policy'] + trainer = self._services['trainer'] + replay_buffer = self._services['replay_buffer'] + compute_advantages = self._services['compute_advantages'] + ref_model = self._services['ref_model'] + reward_actor = self._services['reward_actor'] + tokenizer = self._services['tokenizer'] + pad_id = self._services['pad_id'] + mlogger = self._services['mlogger'] + + # Training parameters + group_size = self._cfg.group_size + max_req_tokens = self._cfg.max_req_tokens + max_res_tokens = self._cfg.max_res_tokens + server_url = self._cfg.get("blackjack_env", {}).get("server_url", "http://localhost:8004") + + game_log = setup_game_logger() + + # Training metrics + metrics = { + 'iterations': [], + 'win_rates': [], + 'losses': [] + } + + # Rollout loop (copy-pasted from blackjack_main_fixed.py) + async def continuous_rollouts(): + rollout_count = 0 + + while not self._shutdown_event.is_set(): + all_step_results = [] + + # Play games + for game_idx in range(group_size): + game_id = str(uuid.uuid4())[:8] + step_results = await play_game( + game_idx=game_idx, + game_id=game_id, + server_url=server_url, + policy=policy, + tokenizer=tokenizer, + game_log=game_log, + rollout_count=rollout_count + ) + all_step_results.extend(step_results) + + # Create episodes + episodes = [] + input_ids = torch.ones( + (len(all_step_results), max_req_tokens + max_res_tokens), + dtype=torch.long, + ) + + for i, step_result in enumerate(all_step_results): + episode = Episode( + episode_id=str(uuid.uuid4()), + pad_id=pad_id, + request_len=max_req_tokens, + response_len=max_res_tokens, + game_id=step_result["game_id"], + step_in_game=step_result["step_num"], + completion=step_result["response"], + ) + + episode.reward = await reward_actor.evaluate_response.route( + prompt=step_result["prompt"], + response=step_result["response"].text, + game_reward=step_result["final_reward"], + ) + + episodes.append(episode) + input_ids[i, :max_req_tokens] = episode.request_tensor + input_ids[i, max_req_tokens:] = episode.response_tensor + + # Get reference logprobs + ref_logprobs = await ref_model.forward.route( + input_ids, max_req_tokens, return_logprobs=True + ) + for i, episode in enumerate(episodes): + episode.ref_logprobs = ref_logprobs[i] + + # Compute advantages + advantages = await compute_advantages.compute.call_one(episodes) + for episode, advantage in zip(episodes, advantages): + episode.advantage = advantage + await replay_buffer.add.call_one(episode) + + rollout_count += 1 + + # Track win rate + wins = sum(1 for e in episodes if e.reward > 0) + win_rate = wins / len(episodes) if episodes else 0 + print(f"šŸ“Š Rollout {rollout_count}: {len(episodes)} episodes, Win rate: {win_rate:.1%}") + + # Training loop (copy-pasted from blackjack_main_fixed.py) + async def continuous_training(): + training_step = 0 + + while training_step < steps: + batch = await replay_buffer.sample.call_one(curr_policy_version=training_step) + if batch is None: + await asyncio.sleep(0.1) + continue + + 
inputs, targets = batch + await trainer.train_step.call(inputs, targets) + training_step += 1 + + await trainer.push_weights.call(training_step) + await policy.update_weights.fanout(training_step) + + if training_step >= 2: + await drop_weights(training_step - 1) + + await mlogger.flush.call_one(training_step) + + print(f"āœ… Training step {training_step}/{steps}") + + print(f"\nšŸŽ‰ Training complete!") + + # Run both loops + rollout_task = asyncio.create_task(continuous_rollouts()) + training_task = asyncio.create_task(continuous_training()) + + try: + await training_task + finally: + self._shutdown_event.set() + try: + await asyncio.wait_for(rollout_task, timeout=5) + except asyncio.TimeoutError: + rollout_task.cancel() + + return metrics + + async def shutdown(self): + """Shutdown all Forge services.""" + await shutdown() + + +async def setup_forge_training(config_path: str) -> GRPOTrainer: + """ + Setup Forge GRPO training infrastructure. + + This function hides all the complexity of initializing Forge services. + + Args: + config_path: Path to YAML config file + + Returns: + GRPOTrainer instance with simple interface + """ + from omegaconf import OmegaConf + + # Load config + cfg = OmegaConf.load(config_path) + + print("šŸ—ļø Initializing Forge infrastructure...\n") + + # Initialize provisioner + if cfg.get("provisioner", None) is not None: + provisioner = await init_provisioner( + ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) + ) + else: + provisioner = await init_provisioner() + print(" āœ… Provisioner") + + # Initialize metric logging + metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) + mlogger = await get_or_create_metric_logger() + await mlogger.init_backends.call_one(metric_logging_cfg) + print(" āœ… Metric Logger") + + # Initialize all services (copy-pasted from blackjack_main_fixed.py) + print("\n šŸš€ Initializing services...") + ( + env_actor, + policy, + trainer, + replay_buffer, + compute_advantages, + ref_model, + reward_actor, + ) = await asyncio.gather( + EnvironmentActor.options(**cfg.actors.get("blackjack_env", cfg.actors.get("env_actor", {}))).as_actor(**cfg.get("blackjack_env", {})), + Generator.options(**cfg.services.policy).as_service(**cfg.policy), + RLTrainer.options(**cfg.actors.trainer).as_actor(**cfg.trainer, loss=simple_grpo_loss), + ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(**cfg.replay_buffer, collate=collate), + ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(), + ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model), + BlackJackReward.options(**cfg.services.reward_actor).as_service(), + ) + + print(" āœ… All services initialized") + + # Initialize torchstore + trainer_num_procs = cfg.actors.trainer["procs"] + trainer_host_mesh_name = cfg.actors.trainer["mesh_name"] + trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name) + await ts.initialize( + mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}), + strategy=ts.LocalRankStrategy(), + ) + print(" āœ… Torchstore") + + # Get tokenizer + tokenizer = await env_actor.get_tokenizer.call_one() + pad_id = await env_actor.pad_token.call_one() + + print("\nāœ… Forge ready for training!\n") + + # Package services + services = { + 'provisioner': provisioner, + 'mlogger': mlogger, + 'env_actor': env_actor, + 'policy': policy, + 'trainer': trainer, + 'replay_buffer': replay_buffer, + 'compute_advantages': compute_advantages, + 'ref_model': ref_model, + 'reward_actor': 
reward_actor, + 'tokenizer': tokenizer, + 'pad_id': pad_id, + } + + return GRPOTrainer(services, cfg)
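
# ============================================================================
# Example Usage (sketch)
# ============================================================================
#
# A minimal sketch of how you might drive this module from a script, assuming a
# BlackJack server is running on localhost:8004 and blackjack.yaml is in the
# working directory (the tutorial notebook simply awaits these calls directly):
#
#   import asyncio
#
#   async def main():
#       trainer = await setup_forge_training("blackjack.yaml")
#       try:
#           await trainer.run(steps=100)
#       finally:
#           await trainer.shutdown()
#
#   asyncio.run(main())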