diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index ef753e07..bc1e55fd 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -73,6 +73,8 @@ jobs: dockerfile: src/envs/chat_env/server/Dockerfile - name: coding-env dockerfile: src/envs/coding_env/server/Dockerfile + - name: atari-env + dockerfile: src/envs/atari_env/server/Dockerfile steps: - name: Checkout code diff --git a/examples/atari_simple.py b/examples/atari_simple.py new file mode 100644 index 00000000..e0e6a743 --- /dev/null +++ b/examples/atari_simple.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +""" +Simple example demonstrating Atari Environment usage. + +This example shows how to: +1. Connect to an Atari environment +2. Reset the environment +3. Take random actions +4. Process observations + +Usage: + # First, start the server: + python -m envs.atari_env.server.app + + # Then run this script: + python examples/atari_simple.py +""" + +import numpy as np +from envs.atari_env import AtariEnv, AtariAction + + +def main(): + """Run a simple Atari episode.""" + # Connect to the Atari environment server + print("Connecting to Atari environment...") + env = AtariEnv(base_url="http://localhost:8000") + + try: + # Reset the environment + print("\nResetting environment...") + result = env.reset() + print(f"Screen shape: {result.observation.screen_shape}") + print(f"Legal actions: {result.observation.legal_actions}") + print(f"Lives: {result.observation.lives}") + + # Run a few steps with random actions + print("\nTaking random actions...") + episode_reward = 0 + steps = 0 + + for step in range(100): + # Random action + action_id = np.random.choice(result.observation.legal_actions) + + # Take action + result = env.step(AtariAction(action_id=action_id)) + + episode_reward += result.reward or 0 + steps += 1 + + # Print progress + if step % 10 == 0: + print( + f"Step {step}: reward={result.reward:.2f}, " + f"lives={result.observation.lives}, 
done={result.done}" + ) + + if result.done: + print(f"\nEpisode finished after {steps} steps!") + break + + print(f"\nTotal episode reward: {episode_reward:.2f}") + + # Get environment state + state = env.state() + print(f"\nEnvironment state:") + print(f" Game: {state.game_name}") + print(f" Episode: {state.episode_id}") + print(f" Steps: {state.step_count}") + print(f" Obs type: {state.obs_type}") + + finally: + # Cleanup + print("\nClosing environment...") + env.close() + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/src/envs/atari_env/README.md b/src/envs/atari_env/README.md new file mode 100644 index 00000000..4171375a --- /dev/null +++ b/src/envs/atari_env/README.md @@ -0,0 +1,383 @@ +# Atari Environment + +Integration of Atari 2600 games with the OpenEnv framework via the Arcade Learning Environment (ALE). ALE provides access to 100+ classic Atari games for RL research. + +## Supported Games + +ALE supports 100+ Atari 2600 games including: + +### Popular Games +- **Pong** - Classic two-player tennis +- **Breakout** - Break bricks with a ball +- **Space Invaders** - Shoot descending aliens +- **Pac-Man / Ms. Pac-Man** - Navigate mazes and eat pellets +- **Asteroids** - Destroy asteroids in space +- **Defender** - Side-scrolling space shooter +- **Centipede** - Shoot segmented centipede +- **Donkey Kong** - Jump over barrels to save princess +- **Frogger** - Cross road and river safely +- **Q*bert** - Jump on pyramid cubes + +And many more! For a complete list, see [ALE documentation](https://ale.farama.org/environments/complete_list/). 
+ +## Architecture + +``` +┌────────────────────────────────────┐ +│ RL Training Code (Client) │ +│ AtariEnv.step(action) │ +└──────────────┬─────────────────────┘ + │ HTTP +┌──────────────▼─────────────────────┐ +│ FastAPI Server (Docker) │ +│ AtariEnvironment │ +│ ├─ Wraps ALEInterface │ +│ ├─ Handles observations │ +│ └─ Action execution │ +└────────────────────────────────────┘ +``` + +## Installation & Usage + +### Option 1: Local Development (without Docker) + +**Requirements:** +- Python 3.11+ +- ale-py installed: `pip install ale-py` + +```python +from envs.atari_env import AtariEnv, AtariAction + +# Start local server manually +# python -m envs.atari_env.server.app + +# Connect to local server +env = AtariEnv(base_url="http://localhost:8000") + +# Reset environment +result = env.reset() +print(f"Screen shape: {result.observation.screen_shape}") +print(f"Legal actions: {result.observation.legal_actions}") +print(f"Lives: {result.observation.lives}") + +# Take actions +for _ in range(10): + action_id = 2 # UP action + result = env.step(AtariAction(action_id=action_id, game_name="pong")) + print(f"Reward: {result.reward}, Done: {result.done}") + if result.done: + break + +# Cleanup +env.close() +``` + +### Option 2: Docker (Recommended) + +**Build Atari image:** + +```bash +cd OpenEnv + +# Build the image +docker build \ + -f src/envs/atari_env/server/Dockerfile \ + -t atari-env:latest \ + . +``` + +**Run specific games:** + +```bash +# Pong (default) +docker run -p 8000:8000 atari-env:latest + +# Breakout +docker run -p 8000:8000 -e ATARI_GAME=breakout atari-env:latest + +# Space Invaders with grayscale observation +docker run -p 8000:8000 \ + -e ATARI_GAME=space_invaders \ + -e ATARI_OBS_TYPE=grayscale \ + atari-env:latest + +# Ms. 
Pac-Man with full action space +docker run -p 8000:8000 \ + -e ATARI_GAME=ms_pacman \ + -e ATARI_FULL_ACTION_SPACE=true \ + atari-env:latest +``` + +**Use with from_docker_image():** + +```python +from envs.atari_env import AtariEnv, AtariAction +import numpy as np + +# Automatically starts container +env = AtariEnv.from_docker_image("atari-env:latest") + +result = env.reset() +result = env.step(AtariAction(action_id=2)) # UP + +# Reshape screen for visualization +screen = np.array(result.observation.screen).reshape(result.observation.screen_shape) +print(f"Screen shape: {screen.shape}") # (210, 160, 3) for RGB + +env.close() # Stops container +``` + +## Observation Types + +### 1. RGB (Default) +- **Shape**: [210, 160, 3] +- **Description**: Full-color screen observation +- **Usage**: Most realistic, good for vision-based learning + +```python +docker run -p 8000:8000 -e ATARI_OBS_TYPE=rgb atari-env:latest +``` + +### 2. Grayscale +- **Shape**: [210, 160] +- **Description**: Grayscale screen observation +- **Usage**: Reduced dimensionality, faster processing + +```python +docker run -p 8000:8000 -e ATARI_OBS_TYPE=grayscale atari-env:latest +``` + +### 3. RAM +- **Shape**: [128] +- **Description**: Raw 128-byte Atari 2600 RAM contents +- **Usage**: Compact representation, useful for specific research + +```python +docker run -p 8000:8000 -e ATARI_OBS_TYPE=ram atari-env:latest +``` + +## Action Spaces + +### Minimal Action Set (Default) +Game-specific minimal actions (typically 4-9 actions). +- Pong: 6 actions (NOOP, FIRE, UP, DOWN, etc.) +- Breakout: 4 actions (NOOP, FIRE, LEFT, RIGHT) + +```python +docker run -p 8000:8000 -e ATARI_FULL_ACTION_SPACE=false atari-env:latest +``` + +### Full Action Set +All 18 possible Atari 2600 actions: +0. NOOP +1. FIRE +2. UP +3. RIGHT +4. LEFT +5. DOWN +6. UPRIGHT +7. UPLEFT +8. DOWNRIGHT +9. DOWNLEFT +10. UPFIRE +11. RIGHTFIRE +12. LEFTFIRE +13. DOWNFIRE +14. UPRIGHTFIRE +15. UPLEFTFIRE +16. DOWNRIGHTFIRE +17. 
DOWNLEFTFIRE + +```python +docker run -p 8000:8000 -e ATARI_FULL_ACTION_SPACE=true atari-env:latest +``` + +## Configuration + +### Environment Variables + +- `ATARI_GAME`: Game name (default: "pong") +- `ATARI_OBS_TYPE`: Observation type - "rgb", "grayscale", "ram" (default: "rgb") +- `ATARI_FULL_ACTION_SPACE`: Use full action space - "true"/"false" (default: "false") +- `ATARI_MODE`: Game mode (optional, game-specific) +- `ATARI_DIFFICULTY`: Game difficulty (optional, game-specific) +- `ATARI_REPEAT_ACTION_PROB`: Sticky action probability 0.0-1.0 (default: "0.0") +- `ATARI_FRAMESKIP`: Frames to skip per action (default: "4") + +### Example: Breakout with Custom Settings + +```bash +docker run -p 8000:8000 \ + -e ATARI_GAME=breakout \ + -e ATARI_OBS_TYPE=grayscale \ + -e ATARI_FULL_ACTION_SPACE=true \ + -e ATARI_REPEAT_ACTION_PROB=0.25 \ + -e ATARI_FRAMESKIP=4 \ + atari-env:latest +``` + +## API Reference + +### AtariAction + +```python +@dataclass +class AtariAction(Action): + action_id: int # Action index to execute + game_name: str = "pong" # Game name + obs_type: str = "rgb" # Observation type + full_action_space: bool = False # Full or minimal action space +``` + +### AtariObservation + +```python +@dataclass +class AtariObservation(Observation): + screen: List[int] # Flattened screen pixels + screen_shape: List[int] # Original screen shape + legal_actions: List[int] # Legal action indices + lives: int # Lives remaining + episode_frame_number: int # Frame # in episode + frame_number: int # Total frame # + done: bool # Episode finished + reward: Optional[float] # Reward from last action +``` + +### AtariState + +```python +@dataclass +class AtariState(State): + episode_id: str # Unique episode ID + step_count: int # Number of steps + game_name: str # Game name + obs_type: str # Observation type + full_action_space: bool # Action space type + mode: Optional[int] # Game mode + difficulty: Optional[int] # Game difficulty + repeat_action_probability: float # 
Sticky action prob
+    frameskip: int                    # Frameskip setting
+```
+
+## Example Script
+
+```python
+#!/usr/bin/env python3
+"""Example training loop with Atari environment."""
+
+import numpy as np
+from envs.atari_env import AtariEnv, AtariAction
+
+# Start environment
+env = AtariEnv.from_docker_image("atari-env:latest")
+
+# Training loop
+for episode in range(10):
+    result = env.reset()
+    episode_reward = 0
+    steps = 0
+
+    while not result.done:
+        # Random policy (replace with your RL agent)
+        action_id = np.random.choice(result.observation.legal_actions)
+
+        # Take action
+        result = env.step(AtariAction(action_id=action_id))
+
+        episode_reward += result.reward or 0
+        steps += 1
+
+        # Reshape screen for processing
+        screen = np.array(result.observation.screen).reshape(
+            result.observation.screen_shape
+        )
+
+        # Your RL training code here
+        # ...
+
+    print(f"Episode {episode}: reward={episode_reward:.2f}, steps={steps}")
+
+env.close()
+```
+
+## Testing
+
+### Local Testing
+
+```bash
+# Install dependencies
+pip install ale-py fastapi uvicorn requests
+
+# Start server (from your OpenEnv checkout)
+cd /path/to/OpenEnv
+export PYTHONPATH=/path/to/OpenEnv/src
+python -m envs.atari_env.server.app
+
+# Test from another terminal
+python -c "
+from envs.atari_env import AtariEnv, AtariAction
+env = AtariEnv(base_url='http://localhost:8000')
+result = env.reset()
+print(f'Initial obs: {result.observation.screen_shape}')
+result = env.step(AtariAction(action_id=2))
+print(f'After step: reward={result.reward}, done={result.done}')
+env.close()
+"
+```
+
+### Docker Testing
+
+```bash
+# Build and run
+docker build -f src/envs/atari_env/server/Dockerfile -t atari-env:latest .
+docker run -p 8000:8000 atari-env:latest + +# Test in another terminal +curl http://localhost:8000/health +curl -X POST http://localhost:8000/reset +``` + +## Popular Games and Their Characteristics + +| Game | Minimal Actions | Lives | Difficulty | Notes | +|------|----------------|-------|-----------|-------| +| Pong | 6 | 1 | Low | Good for learning basics | +| Breakout | 4 | 5 | Medium | Classic RL benchmark | +| Space Invaders | 6 | 3 | Medium | Shooting game | +| Ms. Pac-Man | 9 | 3 | High | Complex navigation | +| Asteroids | 14 | 3 | Medium | Continuous shooting | +| Montezuma's Revenge | 18 | 5 | Very High | Exploration challenge | +| Pitfall | 18 | 1 | High | Platformer | +| Seaquest | 18 | 3 | High | Submarine rescue | + +## Limitations & Notes + +- **Frame perfect timing**: Some games require precise timing +- **Exploration**: Games like Montezuma's Revenge are notoriously difficult +- **Observation delay**: HTTP adds minimal latency vs local gym +- **Determinism**: Set `ATARI_REPEAT_ACTION_PROB=0.0` for deterministic behavior +- **ROMs**: All ROMs are bundled with ale-py package + +## References + +- [Arcade Learning Environment Paper (2013)](https://jair.org/index.php/jair/article/view/10819) +- [ALE GitHub](https://github.com/Farama-Foundation/Arcade-Learning-Environment) +- [ALE Documentation](https://ale.farama.org/) +- [Gymnasium Atari Environments](https://gymnasium.farama.org/environments/atari/) + +## Citation + +If you use ALE in your research, please cite: + +```bibtex +@Article{bellemare13arcade, + author = {{Bellemare}, M.~G. and {Naddaf}, Y. and {Veness}, J. 
and {Bowling}, M.}, + title = {The Arcade Learning Environment: An Evaluation Platform for General Agents}, + journal = {Journal of Artificial Intelligence Research}, + year = "2013", + month = "jun", + volume = "47", + pages = "253--279", +} +``` diff --git a/src/envs/atari_env/__init__.py b/src/envs/atari_env/__init__.py new file mode 100644 index 00000000..5ea68431 --- /dev/null +++ b/src/envs/atari_env/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Atari Environment for OpenEnv. + +This module provides OpenEnv integration for Atari 2600 games via the +Arcade Learning Environment (ALE). + +Example: + >>> from envs.atari_env import AtariEnv, AtariAction + >>> + >>> # Connect to a running server or start via Docker + >>> env = AtariEnv.from_docker_image("atari-env:latest") + >>> + >>> # Reset and interact + >>> result = env.reset() + >>> result = env.step(AtariAction(action_id=2)) # UP + >>> print(result.reward, result.done) + >>> + >>> # Cleanup + >>> env.close() +""" + +from .client import AtariEnv +from .models import AtariAction, AtariObservation, AtariState + +__all__ = ["AtariEnv", "AtariAction", "AtariObservation", "AtariState"] diff --git a/src/envs/atari_env/client.py b/src/envs/atari_env/client.py new file mode 100644 index 00000000..60eab107 --- /dev/null +++ b/src/envs/atari_env/client.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Atari Environment HTTP Client. + +This module provides the client for connecting to an Atari Environment server +over HTTP. 
+""" + +from __future__ import annotations + +from typing import Any, Dict, TYPE_CHECKING + +from core.http_env_client import HTTPEnvClient +from core.types import StepResult + +from .models import AtariAction, AtariObservation, AtariState + +if TYPE_CHECKING: + from core.containers.runtime import ContainerProvider + + +class AtariEnv(HTTPEnvClient[AtariAction, AtariObservation]): + """ + HTTP client for Atari Environment. + + This client connects to an AtariEnvironment HTTP server and provides + methods to interact with it: reset(), step(), and state access. + + Example: + >>> # Connect to a running server + >>> client = AtariEnv(base_url="http://localhost:8000") + >>> result = client.reset() + >>> print(result.observation.screen_shape) + >>> + >>> # Take an action + >>> result = client.step(AtariAction(action_id=2)) # UP + >>> print(result.reward, result.done) + + Example with Docker: + >>> # Automatically start container and connect + >>> client = AtariEnv.from_docker_image("atari-env:latest") + >>> result = client.reset() + >>> result = client.step(AtariAction(action_id=0)) # NOOP + """ + + def _step_payload(self, action: AtariAction) -> Dict[str, Any]: + """ + Convert AtariAction to JSON payload for step request. + + Args: + action: AtariAction instance. + + Returns: + Dictionary representation suitable for JSON encoding. + """ + return { + "action_id": action.action_id, + "game_name": action.game_name, + "obs_type": action.obs_type, + "full_action_space": action.full_action_space, + } + + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[AtariObservation]: + """ + Parse server response into StepResult[AtariObservation]. + + Args: + payload: JSON response from server. + + Returns: + StepResult with AtariObservation. 
+ """ + obs_data = payload.get("observation", {}) + + observation = AtariObservation( + screen=obs_data.get("screen", []), + screen_shape=obs_data.get("screen_shape", []), + legal_actions=obs_data.get("legal_actions", []), + lives=obs_data.get("lives", 0), + episode_frame_number=obs_data.get("episode_frame_number", 0), + frame_number=obs_data.get("frame_number", 0), + done=payload.get("done", False), + reward=payload.get("reward"), + metadata=obs_data.get("metadata", {}), + ) + + return StepResult( + observation=observation, + reward=payload.get("reward"), + done=payload.get("done", False), + ) + + def _parse_state(self, payload: Dict[str, Any]) -> AtariState: + """ + Parse server response into AtariState object. + + Args: + payload: JSON response from /state endpoint. + + Returns: + AtariState object with environment state information. + """ + return AtariState( + episode_id=payload.get("episode_id"), + step_count=payload.get("step_count", 0), + game_name=payload.get("game_name", "unknown"), + obs_type=payload.get("obs_type", "rgb"), + full_action_space=payload.get("full_action_space", False), + mode=payload.get("mode"), + difficulty=payload.get("difficulty"), + repeat_action_probability=payload.get("repeat_action_probability", 0.0), + frameskip=payload.get("frameskip", 4), + ) diff --git a/src/envs/atari_env/models.py b/src/envs/atari_env/models.py new file mode 100644 index 00000000..1938172e --- /dev/null +++ b/src/envs/atari_env/models.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Data models for Atari Environment. + +This module defines the Action, Observation, and State types for Atari games +via the Arcade Learning Environment (ALE). 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Literal, Optional + +from core.env_server import Action, Observation, State + + +@dataclass +class AtariAction(Action): + """ + Action for Atari environments. + + Attributes: + action_id: The integer action ID to take (from legal_actions). + game_name: Name of the Atari game (e.g., "pong", "breakout", "space_invaders"). + obs_type: Observation type ("rgb", "grayscale", or "ram"). + full_action_space: Whether to use full (18 actions) or minimal action space. + """ + action_id: int + game_name: str = "pong" + obs_type: Literal["rgb", "grayscale", "ram"] = "rgb" + full_action_space: bool = False + + +@dataclass +class AtariObservation(Observation): + """ + Observation from Atari environment. + + This represents what the agent sees after taking an action. + + Attributes: + screen: Screen observation as a flattened list of pixels. + Shape depends on obs_type: + - rgb: [210, 160, 3] flattened + - grayscale: [210, 160] flattened + - ram: [128] (RAM contents) + screen_shape: Original shape of the screen before flattening. + legal_actions: List of legal action IDs the agent can take. + lives: Number of lives remaining. + episode_frame_number: Frame number within current episode. + frame_number: Total frame number since environment creation. + """ + screen: List[int] + screen_shape: List[int] + legal_actions: List[int] + lives: int = 0 + episode_frame_number: int = 0 + frame_number: int = 0 + + +@dataclass +class AtariState(State): + """ + State for Atari environment. + + Attributes: + game_name: Name of the Atari game. + obs_type: Observation type ("rgb", "grayscale", or "ram"). + full_action_space: Whether using full or minimal action space. + mode: Game mode (if applicable). + difficulty: Game difficulty (if applicable). + repeat_action_probability: Probability of repeating previous action (sticky actions). 
+ frameskip: Number of frames to skip per action. + """ + game_name: str = "pong" + obs_type: Literal["rgb", "grayscale", "ram"] = "rgb" + full_action_space: bool = False + mode: Optional[int] = None + difficulty: Optional[int] = None + repeat_action_probability: float = 0.0 + frameskip: int = 4 diff --git a/src/envs/atari_env/server/Dockerfile b/src/envs/atari_env/server/Dockerfile new file mode 100644 index 00000000..3ad6d14d --- /dev/null +++ b/src/envs/atari_env/server/Dockerfile @@ -0,0 +1,43 @@ +# Dockerfile for Atari Environment +# This image provides Atari 2600 games via the Arcade Learning Environment (ALE) + +# Configurable base image - defaults to local build, can be overridden for CI/CD +# Base image provides: fastapi, uvicorn, requests, curl, PYTHONPATH=/app/src +# +# Local build: docker build -t envtorch-base:latest -f src/core/containers/images/Dockerfile . +# docker build -f src/envs/atari_env/server/Dockerfile -t atari-env:latest . +# +# CI/CD build: docker build --build-arg BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest \ +# -f src/envs/atari_env/server/Dockerfile -t atari-env:latest . 
+ARG BASE_IMAGE=envtorch-base:latest
+FROM ${BASE_IMAGE}
+
+# Install ALE-specific dependencies
+# ale-py includes all Atari ROMs by default and requires gymnasium
+RUN pip install --no-cache-dir \
+    "gymnasium>=0.29.0" \
+    "ale-py>=0.8.0" \
+    "numpy>=1.24.0"
+
+# Copy OpenEnv core (base image already set WORKDIR=/app)
+COPY src/core/ /app/src/core/
+
+# Copy Atari environment code
+COPY src/envs/atari_env/ /app/src/envs/atari_env/
+
+# Atari-specific environment variables (can be overridden at runtime)
+ENV ATARI_GAME=pong
+ENV ATARI_OBS_TYPE=rgb
+ENV ATARI_FULL_ACTION_SPACE=false
+ENV ATARI_REPEAT_ACTION_PROB=0.0
+ENV ATARI_FRAMESKIP=4
+
+# Expose port
+EXPOSE 8000
+
+# Health check
+HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
+    CMD curl -f http://localhost:8000/health || exit 1
+
+# Run the FastAPI server
+CMD ["uvicorn", "envs.atari_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
diff --git a/src/envs/atari_env/server/__init__.py b/src/envs/atari_env/server/__init__.py
new file mode 100644
index 00000000..266366ba
--- /dev/null
+++ b/src/envs/atari_env/server/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Atari Environment Server.
+
+Server-side implementation of Atari environment for OpenEnv.
+"""
+
+from .atari_environment import AtariEnvironment
+
+__all__ = ["AtariEnvironment"]
diff --git a/src/envs/atari_env/server/app.py b/src/envs/atari_env/server/app.py
new file mode 100644
index 00000000..8b586df2
--- /dev/null
+++ b/src/envs/atari_env/server/app.py
@@ -0,0 +1,73 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +""" +FastAPI application for the Atari Environment. + +This module creates an HTTP server that exposes Atari games +over HTTP endpoints, making them compatible with HTTPEnvClient. + +Usage: + # Development (with auto-reload): + uvicorn envs.atari_env.server.app:app --reload --host 0.0.0.0 --port 8000 + + # Production: + uvicorn envs.atari_env.server.app:app --host 0.0.0.0 --port 8000 --workers 4 + + # Or run directly: + python -m envs.atari_env.server.app + +Environment variables: + ATARI_GAME: Game name to serve (default: "pong") + ATARI_OBS_TYPE: Observation type (default: "rgb") + ATARI_FULL_ACTION_SPACE: Use full action space (default: "false") + ATARI_MODE: Game mode (optional) + ATARI_DIFFICULTY: Game difficulty (optional) + ATARI_REPEAT_ACTION_PROB: Sticky action probability (default: "0.0") + ATARI_FRAMESKIP: Frameskip (default: "4") +""" + +import os + +from core.env_server import create_fastapi_app + +from ..models import AtariAction, AtariObservation +from .atari_environment import AtariEnvironment + +# Get configuration from environment variables +game_name = os.getenv("ATARI_GAME", "pong") +obs_type = os.getenv("ATARI_OBS_TYPE", "rgb") +full_action_space = os.getenv("ATARI_FULL_ACTION_SPACE", "false").lower() == "true" +repeat_action_prob = float(os.getenv("ATARI_REPEAT_ACTION_PROB", "0.0")) +frameskip = int(os.getenv("ATARI_FRAMESKIP", "4")) + +# Optional parameters +mode = os.getenv("ATARI_MODE") +difficulty = os.getenv("ATARI_DIFFICULTY") + +# Convert to int if specified +mode = int(mode) if mode is not None else None +difficulty = int(difficulty) if difficulty is not None else None + +# Create the environment instance +env = AtariEnvironment( + game_name=game_name, + obs_type=obs_type, + full_action_space=full_action_space, + mode=mode, + difficulty=difficulty, + repeat_action_probability=repeat_action_prob, + frameskip=frameskip, +) + +# Create the FastAPI app with routes +app = create_fastapi_app(env, AtariAction, AtariObservation) + + +if 
__name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/src/envs/atari_env/server/atari_environment.py b/src/envs/atari_env/server/atari_environment.py new file mode 100644 index 00000000..6d6b5362 --- /dev/null +++ b/src/envs/atari_env/server/atari_environment.py @@ -0,0 +1,245 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Atari Environment Server Implementation. + +This module wraps ALE's ALEInterface and exposes it +via the OpenEnv Environment interface. +""" + +import uuid +from typing import Any, Dict, Literal, Optional + +from core.env_server import Action, Environment, Observation + +from ..models import AtariAction, AtariObservation, AtariState + +# Import ALE +try: + from ale_py import ALEInterface, roms + import numpy as np +except ImportError as e: + raise ImportError( + "ALE (Arcade Learning Environment) is not installed. " + "Please install it with: pip install ale-py" + ) from e + + +class AtariEnvironment(Environment): + """ + Atari Environment wrapper for OpenEnv. + + This environment wraps Atari 2600 games via the Arcade Learning Environment (ALE) + and provides a clean interface for RL training. + + Supported games include: pong, breakout, space_invaders, and 100+ others. + + Args: + game_name: Name of the Atari game (e.g., "pong", "breakout"). + obs_type: Observation type - "rgb", "grayscale", or "ram". + full_action_space: Use full action space (18 actions) vs minimal. + mode: Game mode (if applicable). + difficulty: Game difficulty (if applicable). + repeat_action_probability: Sticky action probability (default 0.0). + frameskip: Number of frames to skip per action (default 4). 
+ + Example: + >>> env = AtariEnvironment("pong") + >>> obs = env.reset() + >>> print(obs.screen_shape) # [210, 160, 3] + >>> obs = env.step(AtariAction(action_id=2)) # UP + >>> print(obs.reward, obs.done) + """ + + def __init__( + self, + game_name: str = "pong", + obs_type: Literal["rgb", "grayscale", "ram"] = "rgb", + full_action_space: bool = False, + mode: Optional[int] = None, + difficulty: Optional[int] = None, + repeat_action_probability: float = 0.0, + frameskip: int = 4, + ): + """Initialize Atari environment.""" + super().__init__() + + self.game_name = game_name + self.obs_type = obs_type + self.full_action_space = full_action_space + self.mode = mode + self.difficulty = difficulty + self.repeat_action_probability = repeat_action_probability + self.frameskip = frameskip + + # Create ALE interface + self.ale = ALEInterface() + + # Configure ALE + from ale_py import LoggerMode + self.ale.setLoggerMode(LoggerMode.Error) # Error mode only + self.ale.setFloat("repeat_action_probability", repeat_action_probability) + + # Load ROM + try: + rom_path = roms.get_rom_path(game_name) + self.ale.loadROM(rom_path) + except Exception as e: + raise ValueError( + f"Failed to load Atari game '{game_name}': {e}\n" + f"Available games can be found via: ale_py.roms.list_roms()" + ) from e + + # Set mode and difficulty if specified + if mode is not None: + self.ale.setMode(mode) + if difficulty is not None: + self.ale.setDifficulty(difficulty) + + # Get action set + if full_action_space: + self._action_set = self.ale.getLegalActionSet() + else: + self._action_set = self.ale.getMinimalActionSet() + + # Get screen dimensions for observation space + self.screen_height, self.screen_width = self.ale.getScreenDims() + if obs_type == "rgb": + self.screen_shape = [self.screen_height, self.screen_width, 3] + elif obs_type == "grayscale": + self.screen_shape = [self.screen_height, self.screen_width] + elif obs_type == "ram": + self.screen_shape = [self.ale.getRAMSize()] + else: + 
raise ValueError(f"Invalid obs_type: {obs_type}") + + # Initialize state + self._state = AtariState( + game_name=game_name, + obs_type=obs_type, + full_action_space=full_action_space, + mode=mode, + difficulty=difficulty, + repeat_action_probability=repeat_action_probability, + frameskip=frameskip, + ) + + def reset(self) -> Observation: + """ + Reset the environment and return initial observation. + + Returns: + Initial observation for the agent. + """ + # Reset ALE + self.ale.reset_game() + + # Reset state tracking + self._state.episode_id = str(uuid.uuid4()) + self._state.step_count = 0 + + # Get initial observation + return self._make_observation() + + def step(self, action: Action) -> Observation: + """ + Execute agent's action and return resulting observation. + + Args: + action: AtariAction containing the action_id to execute. + + Returns: + Observation after action execution. + + Raises: + ValueError: If action is not an AtariAction. + """ + if not isinstance(action, AtariAction): + raise ValueError(f"Expected AtariAction, got {type(action)}") + + # Validate action_id + if action.action_id < 0 or action.action_id >= len(self._action_set): + raise ValueError( + f"Invalid action_id: {action.action_id}. " + f"Valid range: [0, {len(self._action_set) - 1}]" + ) + + # Get actual ALE action + ale_action = self._action_set[action.action_id] + + # Execute action with frameskip + total_reward = 0.0 + for _ in range(self.frameskip): + total_reward += self.ale.act(ale_action) + if self.ale.game_over(): + break + + self._state.step_count += 1 + + # Get observation + obs = self._make_observation() + obs.reward = total_reward + + return obs + + @property + def state(self) -> AtariState: + """Get current environment state.""" + return self._state + + def _make_observation(self) -> AtariObservation: + """ + Create an AtariObservation from current ALE state. + + Returns: + AtariObservation for the agent. 
+ """ + # Get screen observation + if self.obs_type == "rgb": + screen = self.ale.getScreenRGB() + elif self.obs_type == "grayscale": + screen = self.ale.getScreenGrayscale() + elif self.obs_type == "ram": + screen = self.ale.getRAM() + else: + raise ValueError(f"Invalid obs_type: {self.obs_type}") + + # Flatten screen for JSON serialization + # Handle both numpy arrays and lists + if hasattr(screen, "flatten"): + screen_flat = screen.flatten().tolist() + elif hasattr(screen, "tolist"): + screen_flat = screen.tolist() + else: + screen_flat = list(screen) + + # Get game info + lives = self.ale.lives() + episode_frame_number = self.ale.getEpisodeFrameNumber() + frame_number = self.ale.getFrameNumber() + done = self.ale.game_over() + + # Create legal actions list (indices into action_set) + legal_actions = list(range(len(self._action_set))) + + # Create observation + obs = AtariObservation( + screen=screen_flat, + screen_shape=self.screen_shape, + legal_actions=legal_actions, + lives=lives, + episode_frame_number=episode_frame_number, + frame_number=frame_number, + done=done, + reward=0.0, # Will be filled in by step() + metadata={ + "game_name": self.game_name, + "action_meanings": [str(a) for a in self._action_set], + }, + ) + + return obs diff --git a/src/envs/atari_env/test_atari_docker.sh b/src/envs/atari_env/test_atari_docker.sh new file mode 100755 index 00000000..34fa98cc --- /dev/null +++ b/src/envs/atari_env/test_atari_docker.sh @@ -0,0 +1,333 @@ +#!/bin/bash +# Comprehensive Docker test for Atari environment +# Tests: Build, Start, Health, Reset, Step, State, Cleanup + +set -e # Exit on error + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# Configuration +IMAGE_NAME="atari-env" +IMAGE_TAG="test" +CONTAINER_NAME="atari-env-test" +PORT="8765" # Use non-standard port to avoid conflicts +HEALTH_RETRIES=30 +HEALTH_DELAY=2 + +# Cleanup function +cleanup() { + echo -e 
"\n${BLUE}Cleaning up...${NC}"
+    docker stop ${CONTAINER_NAME} 2>/dev/null || true
+    docker rm ${CONTAINER_NAME} 2>/dev/null || true
+    echo -e "${GREEN}✓${NC} Cleanup complete"
+}
+
+# Set trap to cleanup on exit
+trap cleanup EXIT
+
+# Header
+echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
+echo "  ATARI ENVIRONMENT DOCKER TEST"
+echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
+echo ""
+
+# Check prerequisites
+echo -e "${BLUE}Checking prerequisites...${NC}"
+if ! command -v docker &> /dev/null; then
+    echo -e "${RED}✗${NC} Docker is not installed"
+    exit 1
+fi
+echo -e "${GREEN}✓${NC} Docker is installed"
+
+if ! command -v curl &> /dev/null; then
+    echo -e "${RED}✗${NC} curl is not installed"
+    exit 1
+fi
+echo -e "${GREEN}✓${NC} curl is installed"
+
+# Check if we're in the right directory
+if [ ! -f "src/envs/atari_env/server/Dockerfile" ]; then
+    echo -e "${RED}✗${NC} Must run from OpenEnv root directory"
+    exit 1
+fi
+echo -e "${GREEN}✓${NC} In correct directory"
+
+# Step 1: Build Docker image
+echo ""
+echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
+echo -e "${BLUE}STEP 1: Building Docker Image${NC}"
+echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
+
+echo "Building ${IMAGE_NAME}:${IMAGE_TAG}..."
+# BUGFIX: the original piped the build through `tee | tail`, so the `if`
+# tested tail's exit status (always 0) and a failed build was reported as a
+# success. Test the build's own status, then show the last 20 log lines.
+if docker build -f src/envs/atari_env/server/Dockerfile -t ${IMAGE_NAME}:${IMAGE_TAG} . > /tmp/atari_build.log 2>&1; then
+    tail -n 20 /tmp/atari_build.log
+    echo -e "${GREEN}✓${NC} Docker image built successfully"
+else
+    tail -n 20 /tmp/atari_build.log
+    echo -e "${RED}✗${NC} Docker build failed"
+    echo "See /tmp/atari_build.log for full output"
+    exit 1
+fi
+
+# Check image exists
+if docker image inspect ${IMAGE_NAME}:${IMAGE_TAG} &> /dev/null; then
+    IMAGE_SIZE=$(docker image inspect ${IMAGE_NAME}:${IMAGE_TAG} --format='{{.Size}}' | awk '{print $1/1024/1024}')
+    echo -e "${GREEN}✓${NC} Image size: ${IMAGE_SIZE} MB"
+else
+    echo -e "${RED}✗${NC} Image not found after build"
+    exit 1
+fi
+
+# Step 2: Start container
+echo ""
+echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
+echo -e "${BLUE}STEP 2: Starting Container${NC}"
+echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
+
+# Clean up any existing container
+docker rm -f ${CONTAINER_NAME} 2>/dev/null || true
+
+echo "Starting container on port ${PORT}..."
+# BUGFIX: the original ran `docker run` bare and then checked `$?`; under
+# `set -e` the script exits before that check can ever see a failure, making
+# the else branch dead code. Test the command directly in the `if` instead.
+if docker run -d \
+    --name ${CONTAINER_NAME} \
+    -p ${PORT}:8000 \
+    -e ATARI_GAME=pong \
+    -e ATARI_OBS_TYPE=ram \
+    -e ATARI_FRAMESKIP=4 \
+    ${IMAGE_NAME}:${IMAGE_TAG}; then
+    echo -e "${GREEN}✓${NC} Container started: ${CONTAINER_NAME}"
+else
+    echo -e "${RED}✗${NC} Failed to start container"
+    exit 1
+fi
+
+# Wait for container to be running
+sleep 2
+if docker ps | grep -q ${CONTAINER_NAME}; then
+    echo -e "${GREEN}✓${NC} Container is running"
+else
+    echo -e "${RED}✗${NC} Container is not running"
+    docker logs ${CONTAINER_NAME}
+    exit 1
+fi
+
+# Step 3: Wait for health check
+echo ""
+echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
+echo -e "${BLUE}STEP 3: Waiting for Server${NC}"
+echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
+
+# BUGFIX: the timeout is HEALTH_RETRIES * HEALTH_DELAY seconds (30 * 2 = 60s),
+# not HEALTH_RETRIES seconds as the original message claimed; and `i` counts
+# attempts, not seconds.
+echo "Waiting for server to be ready (timeout: $((HEALTH_RETRIES * HEALTH_DELAY))s)..."
+for i in $(seq 1 ${HEALTH_RETRIES}); do
+    if curl -s http://localhost:${PORT}/health > /dev/null 2>&1; then
+        echo -e "${GREEN}✓${NC} Server is ready (attempt ${i})"
+        break
+    fi
+
+    if [ $i -eq ${HEALTH_RETRIES} ]; then
+        echo -e "${RED}✗${NC} Server did not become ready in time"
+        echo "Container logs:"
+        docker logs ${CONTAINER_NAME}
+        exit 1
+    fi
+
+    echo -n "."
+    sleep ${HEALTH_DELAY}
+done
+
+# Step 4: Test health endpoint
+echo ""
+echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
+echo -e "${BLUE}STEP 4: Testing Health Endpoint${NC}"
+echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
+
+HEALTH_RESPONSE=$(curl -s http://localhost:${PORT}/health)
+echo "Response: ${HEALTH_RESPONSE}"
+
+if echo "${HEALTH_RESPONSE}" | grep -q "healthy"; then
+    echo -e "${GREEN}✓${NC} Health endpoint working"
+else
+    echo -e "${RED}✗${NC} Health endpoint failed"
+    exit 1
+fi
+
+# Step 5: Test reset endpoint
+echo ""
+echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
+echo -e "${BLUE}STEP 5: Testing Reset Endpoint${NC}"
+echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
+
+RESET_RESPONSE=$(curl -s -X POST http://localhost:${PORT}/reset -H "Content-Type: application/json" -d '{}')
+
+if [ -z "${RESET_RESPONSE}" ]; then
+    echo -e "${RED}✗${NC} Reset endpoint returned empty response"
+    docker logs ${CONTAINER_NAME} | tail -20
+    exit 1
+fi
+
+echo "Response (first 200 chars): ${RESET_RESPONSE:0:200}..."
+
+# Check if response contains expected fields
+if echo "${RESET_RESPONSE}" | grep -q "observation" && \
+   echo "${RESET_RESPONSE}" | grep -q "screen" && \
+   echo "${RESET_RESPONSE}" | grep -q "legal_actions"; then
+    echo -e "${GREEN}✓${NC} Reset endpoint working"
+
+    # Extract some info
+    SCREEN_LEN=$(echo "${RESET_RESPONSE}" | grep -o '"screen":\[[^]]*\]' | wc -c)
+    echo "  Screen data length: ${SCREEN_LEN} chars"
+else
+    echo -e "${RED}✗${NC} Reset response missing required fields"
+    echo "Full response: ${RESET_RESPONSE}"
+    exit 1
+fi
+
+# Step 6: Test step endpoint
+echo ""
+echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
+echo -e "${BLUE}STEP 6: Testing Step Endpoint${NC}"
+echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
+
+STEP_PAYLOAD='{"action": {"action_id": 0, "game_name": "pong"}}'
+STEP_RESPONSE=$(curl -s -X POST http://localhost:${PORT}/step -H "Content-Type: application/json" -d "${STEP_PAYLOAD}")
+
+if [ -z "${STEP_RESPONSE}" ]; then
+    echo -e "${RED}✗${NC} Step endpoint returned empty response"
+    docker logs ${CONTAINER_NAME} | tail -20
+    exit 1
+fi
+
+echo "Response (first 200 chars): ${STEP_RESPONSE:0:200}..."
+
+# Check if response contains expected fields
+if echo "${STEP_RESPONSE}" | grep -q "observation" && \
+   echo "${STEP_RESPONSE}" | grep -q "reward" && \
+   echo "${STEP_RESPONSE}" | grep -q "done"; then
+    echo -e "${GREEN}✓${NC} Step endpoint working"
+
+    # Extract reward and done
+    REWARD=$(echo "${STEP_RESPONSE}" | grep -o '"reward":[^,}]*' | cut -d: -f2)
+    DONE=$(echo "${STEP_RESPONSE}" | grep -o '"done":[^,}]*' | cut -d: -f2)
+    echo "  Reward: ${REWARD}"
+    echo "  Done: ${DONE}"
+else
+    echo -e "${RED}✗${NC} Step response missing required fields"
+    echo "Full response: ${STEP_RESPONSE}"
+    exit 1
+fi
+
+# Step 7: Test state endpoint
+echo ""
+echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
+echo -e "${BLUE}STEP 7: Testing State Endpoint${NC}"
+echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
+
+STATE_RESPONSE=$(curl -s http://localhost:${PORT}/state)
+
+if [ -z "${STATE_RESPONSE}" ]; then
+    echo -e "${RED}✗${NC} State endpoint returned empty response"
+    docker logs ${CONTAINER_NAME} | tail -20
+    exit 1
+fi
+
+echo "Response: ${STATE_RESPONSE}"
+
+# Check if response contains expected fields
+if echo "${STATE_RESPONSE}" | grep -q "episode_id" && \
+   echo "${STATE_RESPONSE}" | grep -q "step_count" && \
+   echo "${STATE_RESPONSE}" | grep -q "game_name"; then
+    echo -e "${GREEN}✓${NC} State endpoint working"
+
+    # Extract info
+    GAME_NAME=$(echo "${STATE_RESPONSE}" | grep -o '"game_name":"[^"]*"' | cut -d'"' -f4)
+    STEP_COUNT=$(echo "${STATE_RESPONSE}" | grep -o '"step_count":[^,}]*' | cut -d: -f2)
+    echo "  Game: ${GAME_NAME}"
+    echo "  Steps: ${STEP_COUNT}"
+else
+    echo -e "${RED}✗${NC} State response missing required fields"
+    echo "Full response: ${STATE_RESPONSE}"
+    exit 1
+fi
+
+# Step 8: Test multiple steps
+echo ""
+echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
+echo -e "${BLUE}STEP 8: Testing Multiple Steps${NC}"
+echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
+
+echo "Taking 10 steps..."
+TOTAL_REWARD=0
+for i in {1..10}; do
+    ACTION_ID=$((RANDOM % 3))  # Random action 0-2
+    STEP_PAYLOAD="{\"action\": {\"action_id\": ${ACTION_ID}, \"game_name\": \"pong\"}}"
+    STEP_RESPONSE=$(curl -s -X POST http://localhost:${PORT}/step -H "Content-Type: application/json" -d "${STEP_PAYLOAD}")
+
+    if ! echo "${STEP_RESPONSE}" | grep -q "observation"; then
+        echo -e "${RED}✗${NC} Step ${i} failed"
+        exit 1
+    fi
+
+    REWARD=$(echo "${STEP_RESPONSE}" | grep -o '"reward":[^,}]*' | cut -d: -f2 | sed 's/null/0/')
+    DONE=$(echo "${STEP_RESPONSE}" | grep -o '"done":[^,}]*' | cut -d: -f2)
+
+    echo "  Step ${i}: action=${ACTION_ID}, reward=${REWARD}, done=${DONE}"
+
+    if [ "${DONE}" = "true" ]; then
+        echo "  Episode completed early at step ${i}"
+        break
+    fi
+done
+
+echo -e "${GREEN}✓${NC} Multiple steps completed successfully"
+
+# Step 9: Check container logs for errors
+echo ""
+echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
+echo -e "${BLUE}STEP 9: Checking Container Logs${NC}"
+echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
+
+LOGS=$(docker logs ${CONTAINER_NAME} 2>&1)
+
+if echo "${LOGS}" | grep -i "error" | grep -v "LoggerMode.Error"; then
+    echo -e "${YELLOW}⚠${NC} Found errors in logs:"
+    echo "${LOGS}" | grep -i "error" | head -5
+else
+    echo -e "${GREEN}✓${NC} No errors in container logs"
+fi
+
+if echo "${LOGS}" | grep -i "exception"; then
+    echo -e "${RED}✗${NC} Found exceptions in logs:"
+    echo "${LOGS}" | grep -i "exception" | head -5
+    exit 1
+else
+    echo -e "${GREEN}✓${NC} No exceptions in container logs"
+fi
+
+# Final Summary
+echo ""
+echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
+echo -e "${GREEN}✅ ALL DOCKER TESTS PASSED${NC}"
+echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
+echo ""
+echo "Summary:"
+echo "  ✓ Docker image built successfully"
+echo "  ✓ Container started and ran"
+echo "  ✓ Health endpoint working"
+echo "  ✓ Reset endpoint working"
+echo "  ✓ Step endpoint working"
+echo "  ✓ State endpoint working"
+echo "  ✓ Multiple steps working"
+echo "  ✓ No errors or exceptions"
+echo ""
+echo "Image: ${IMAGE_NAME}:${IMAGE_TAG}"
+echo "Container: ${CONTAINER_NAME}"
+echo "Port: ${PORT}"
+echo ""
+echo "To keep container running: docker start ${CONTAINER_NAME}"
+echo "To view logs: docker logs ${CONTAINER_NAME}"
+echo ""