From 0b15865bb3b5e8ac4decbacb58066199d6d5b514 Mon Sep 17 00:00:00 2001 From: Zach Wentz Date: Fri, 17 Oct 2025 22:34:18 -0400 Subject: [PATCH 01/11] Test workflow From e04a79a12c2188dfccefd3f38f699dfe1e9c946b Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Mon, 17 Nov 2025 05:25:47 +0000 Subject: [PATCH 02/11] refactor: migrate from dataclasses to Pydantic models --- src/core/env_server/http_server.py | 84 +- src/core/env_server/types.py | 114 +- src/core/env_server/web_interface.py | 3311 +++++++++++++------------- 3 files changed, 1825 insertions(+), 1684 deletions(-) diff --git a/src/core/env_server/http_server.py b/src/core/env_server/http_server.py index 207235f6..5a0daba2 100644 --- a/src/core/env_server/http_server.py +++ b/src/core/env_server/http_server.py @@ -16,12 +16,14 @@ import asyncio import os from concurrent.futures import ThreadPoolExecutor -from dataclasses import asdict -from typing import Any, Dict, Type +from typing import Any, Dict, Type, Optional + +from pydantic import ValidationError +from fastapi import Body, FastAPI, HTTPException, status from .interfaces import Environment from .types import Action, Observation -from fastapi import Body, FastAPI + class HTTPEnvServer: """ @@ -95,8 +97,14 @@ async def step(request: Dict[str, Any]) -> Dict[str, Any]: action_data = request.get("action", request) # TODO: Handle timeout_s, request_id, episode_id from request if provided - # Deserialize action - action = self._deserialize_action(action_data) + # Deserialize action with Pydantic validation + try: + action = self._deserialize_action(action_data) + except ValidationError as e: + # Return HTTP 422 with detailed validation errors + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=e.errors() + ) # Execute step in thread pool to avoid blocking asyncio loop loop = asyncio.get_event_loop() @@ -111,17 +119,16 @@ async def step(request: Dict[str, Any]) -> Dict[str, Any]: async def get_state() -> Dict[str, Any]: """State endpoint - returns current environment state.""" state = self.env.state - return asdict(state) + return state.model_dump() @app.get("/health") async def health() -> Dict[str, str]: """Health check endpoint.""" return {"status": "healthy"} - def _deserialize_action(self, action_data: Dict[str, Any]) -> Action: """ - Convert JSON dict to Action instance. + Convert JSON dict to Action instance using Pydantic validation. Args: action_data: Dictionary containing action data @@ -129,19 +136,19 @@ def _deserialize_action(self, action_data: Dict[str, Any]) -> Action: Returns: Action instance + Raises: + ValidationError: If action_data is invalid for the action class + Note: - This is a simple implementation. Subclasses may need to override - for more complex deserialization logic. + This uses Pydantic's model_validate() for automatic validation. """ - # Remove metadata if present (it will be set via kw_only field) - metadata = action_data.pop("metadata", {}) - action = self.action_cls(**action_data) - action.metadata = metadata + # Pydantic handles validation automatically + action = self.action_cls.model_validate(action_data) return action def _serialize_observation(self, observation: Observation) -> Dict[str, Any]: """ - Convert Observation instance to JSON-compatible dict. + Convert Observation instance to JSON-compatible dict using Pydantic. Args: observation: Observation instance @@ -156,25 +163,18 @@ def _serialize_observation(self, observation: Observation) -> Dict[str, Any]: "done": bool, } """ - obs_dict = asdict(observation) - - # Convert numpy arrays to lists for JSON serialization - def _convert_numpy(obj): - """Recursively convert numpy arrays to lists.""" - if hasattr(obj, '__array__'): # numpy array - return obj.tolist() - elif isinstance(obj, dict): - return {k: _convert_numpy(v) for k, v in obj.items()} - elif isinstance(obj, (list, tuple)): - return type(obj)(_convert_numpy(item) for item in obj) - return obj - - obs_dict = _convert_numpy(obs_dict) - - # Extract reward and done (these are part of StepResult on client side) - reward = obs_dict.pop("reward", None) - done = obs_dict.pop("done", False) - obs_dict.pop("metadata", None) # Remove metadata from observation + # Use Pydantic's model_dump() for serialization + obs_dict = observation.model_dump( + exclude={ + "reward", + "done", + "metadata", + } # Exclude these from observation dict + ) + + # Extract reward and done directly from the observation + reward = observation.reward + done = observation.done # Return in HTTPEnvClient expected format return { @@ -183,6 +183,7 @@ def _convert_numpy(obj): "done": done, } + def create_app( env: Environment, action_cls: Type[Action], @@ -191,33 +192,36 @@ def create_app( ) -> Any: """ Create a FastAPI application with or without web interface. - + This function creates a FastAPI app with the web interface enabled by default, including README integration for better user experience. - + Args: env: The Environment instance to serve action_cls: The Action subclass this environment expects observation_cls: The Observation subclass this environment returns env_name: Optional environment name for README loading - + Returns: FastAPI application instance with or without web interface and README integration """ # Check if web interface should be enabled # This can be controlled via environment variable or build argument - enable_web = ( - os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in ("true", "1", "yes") + enable_web = os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in ( + "true", + "1", + "yes", ) if enable_web: # Import web interface only when needed from .web_interface import create_web_interface_app + return create_web_interface_app(env, action_cls, observation_cls, env_name) else: # Use standard FastAPI app without web interface return create_fastapi_app(env, action_cls, observation_cls) - + def create_fastapi_app( env: Environment, diff --git a/src/core/env_server/types.py b/src/core/env_server/types.py index 70da9f3c..2a3256d5 100644 --- a/src/core/env_server/types.py +++ b/src/core/env_server/types.py @@ -4,54 +4,106 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional, Union +from pydantic import BaseModel, Field, ConfigDict # Type aliases Scalar = Union[int, float, bool] -@dataclass(kw_only=True) -class Action: - """Base class for all environment actions.""" +class Action(BaseModel): + """Base class for all environment actions. - metadata: Dict[str, Any] = field(default_factory=dict) + All action subclasses should inherit from this base class. + Uses Pydantic for automatic validation and serialization. + """ + model_config = ConfigDict( + extra="forbid", # Reject unknown fields + validate_assignment=True, # Validate on field assignment + arbitrary_types_allowed=True, # Allow numpy arrays, torch tensors, etc. + ) -@dataclass(kw_only=True) -class Observation: - """Base class for all environment observations.""" + metadata: Dict[str, Any] = Field( + default_factory=dict, description="Additional metadata for the action" + ) - done: bool = False - reward: Union[bool, int, float, None] = None - metadata: Dict[str, Any] = field(default_factory=dict) +class Observation(BaseModel): + """Base class for all environment observations. -@dataclass -class State: - """Base class for environment state.""" + All observation subclasses should inherit from this base class. + Uses Pydantic for automatic validation and serialization. + """ - episode_id: Optional[str] = None - step_count: int = 0 + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + arbitrary_types_allowed=True, + ) + done: bool = Field(default=False, description="Whether the episode has terminated") + reward: Union[bool, int, float, None] = Field( + default=None, description="Reward signal from the last action" + ) + metadata: Dict[str, Any] = Field( + default_factory=dict, description="Additional metadata for the observation" + ) -@dataclass -class CodeExecResult: + +class State(BaseModel): + """Base class for environment state. + + Represents internal environment state, separate from observations. + """ + + model_config = ConfigDict( + extra="allow", # Allow extra fields for flexibility + validate_assignment=True, + arbitrary_types_allowed=True, + ) + + episode_id: Optional[str] = Field( + default=None, description="Unique identifier for the current episode" + ) + step_count: int = Field( + default=0, + ge=0, # Greater than or equal to 0 + description="Number of steps taken in the current episode", + ) + + +class CodeExecResult(BaseModel): """Result of code execution containing stdout, stderr, and exit code.""" - stdout: str - stderr: str - exit_code: int + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + stdout: str = Field(description="Standard output from code execution") + stderr: str = Field(description="Standard error from code execution") + exit_code: int = Field(description="Exit code from code execution") -@dataclass -class EnvironmentMetadata: + +class EnvironmentMetadata(BaseModel): """Metadata about an environment for documentation and UI purposes.""" - - name: str - description: str - readme_content: Optional[str] = None - version: Optional[str] = None - author: Optional[str] = None - documentation_url: Optional[str] = None + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + name: str = Field(description="Name of the environment") + description: str = Field(description="Description of what the environment does") + readme_content: Optional[str] = Field( + default=None, description="Content of the README file for the environment" + ) + version: Optional[str] = Field( + default=None, description="Version of the environment" + ) + author: Optional[str] = Field(default=None, description="Author of the environment") + documentation_url: Optional[str] = Field( + default=None, description="URL to the environment's documentation" + ) diff --git a/src/core/env_server/web_interface.py b/src/core/env_server/web_interface.py index 3c36aa1d..c9f899a5 100644 --- a/src/core/env_server/web_interface.py +++ b/src/core/env_server/web_interface.py @@ -1,1613 +1,1698 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Web interface for OpenEnv environments. - -This module provides a web-based interface for interacting with OpenEnv environments, -including a two-pane layout for HumanAgent interaction and state observation. -""" - -from __future__ import annotations - -import json -import time -from dataclasses import asdict, dataclass -from typing import Any, Dict, List, Optional, Type -from datetime import datetime - -from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request -from fastapi.responses import HTMLResponse, FileResponse -from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel - -from .interfaces import Environment -from .types import Action, Observation, State, EnvironmentMetadata - - -def load_environment_metadata(env: Environment, env_name: Optional[str] = None) -> EnvironmentMetadata: - """ - Load environment metadata including README content. - - Args: - env: The environment instance - env_name: Optional environment name for README file lookup - - Returns: - EnvironmentMetadata with loaded information - """ - # Try to get metadata from environment if it has a method for it - if hasattr(env, 'get_metadata'): - return env.get_metadata() - - # Default metadata - metadata = EnvironmentMetadata( - name=env_name or env.__class__.__name__, - description=f"{env.__class__.__name__} environment", - version="1.0.0" - ) - - # Try to load README from file system - readme_content = _load_readme_from_filesystem(env_name) - if readme_content: - metadata.readme_content = readme_content - - return metadata - - -def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]: - """ - Load README content from the filesystem. - - Tries multiple locations: - 1. Container filesystem: /app/README.md - 2. Local development: src/envs/{env_name}/README.md - 3. Environment variable: ENV_README_PATH - """ - import os - from pathlib import Path - - # Try container filesystem first - container_readme = Path("/app/README.md") - if container_readme.exists(): - try: - return container_readme.read_text(encoding='utf-8') - except Exception: - pass - - # Try environment variable path - custom_path = os.environ.get("ENV_README_PATH") - if custom_path and Path(custom_path).exists(): - try: - return Path(custom_path).read_text(encoding='utf-8') - except Exception: - pass - - # Try local development path - if env_name: - local_readme = Path(f"src/envs/{env_name}/README.md") - if local_readme.exists(): - try: - return local_readme.read_text(encoding='utf-8') - except Exception: - pass - - return None - - -@dataclass -class ActionLog: - """Log entry for an action taken.""" - timestamp: str - action: Dict[str, Any] - observation: Dict[str, Any] - reward: Optional[float] - done: bool - step_count: int - - -@dataclass -class EpisodeState: - """Current episode state for the web interface.""" - episode_id: Optional[str] - step_count: int - current_observation: Optional[Dict[str, Any]] - action_logs: List[ActionLog] - is_reset: bool = True - - -class WebInterfaceManager: - """Manages the web interface for an environment.""" - - def __init__( - self, - env: Environment, - action_cls: Type[Action], - observation_cls: Type[Observation], - metadata: Optional[EnvironmentMetadata] = None, - ): - self.env = env - self.action_cls = action_cls - self.observation_cls = observation_cls - self.metadata = metadata or EnvironmentMetadata( - name=env.__class__.__name__, - description=f"{env.__class__.__name__} environment" - ) - self.episode_state = EpisodeState( - episode_id=None, - step_count=0, - current_observation=None, - action_logs=[] - ) - self.connected_clients: List[WebSocket] = [] - - async def connect_websocket(self, websocket: WebSocket): - """Connect a new WebSocket client.""" - await websocket.accept() - self.connected_clients.append(websocket) - - # Send current state to the new client - await self._send_state_update() - - async def disconnect_websocket(self, websocket: WebSocket): - """Disconnect a WebSocket client.""" - if websocket in self.connected_clients: - self.connected_clients.remove(websocket) - - async def _send_state_update(self): - """Send current state to all connected clients.""" - if not self.connected_clients: - return - - state_data = { - "type": "state_update", - "episode_state": asdict(self.episode_state) - } - - # Send to all connected clients - disconnected_clients = [] - for client in self.connected_clients: - try: - await client.send_text(json.dumps(state_data)) - except: - disconnected_clients.append(client) - - # Remove disconnected clients - for client in disconnected_clients: - self.connected_clients.remove(client) - - async def reset_environment(self) -> Dict[str, Any]: - """Reset the environment and update state.""" - observation = self.env.reset() - state = self.env.state - - # Update episode state - self.episode_state.episode_id = state.episode_id - self.episode_state.step_count = 0 - self.episode_state.current_observation = asdict(observation) - self.episode_state.action_logs = [] - self.episode_state.is_reset = True - - # Send state update - await self._send_state_update() - - return { - "observation": asdict(observation), - "reward": observation.reward, - "done": observation.done, - } - - async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]: - """Execute a step in the environment and update state.""" - # Deserialize action - action = self._deserialize_action(action_data) - - # Execute step - observation = self.env.step(action) - state = self.env.state - - # Create action log - action_log = ActionLog( - timestamp=datetime.now().isoformat(), - action=asdict(action), - observation=asdict(observation), - reward=observation.reward, - done=observation.done, - step_count=state.step_count - ) - - # Update episode state - self.episode_state.episode_id = state.episode_id - self.episode_state.step_count = state.step_count - self.episode_state.current_observation = asdict(observation) - self.episode_state.action_logs.append(action_log) - self.episode_state.is_reset = False - - # Send state update - await self._send_state_update() - - return { - "observation": asdict(observation), - "reward": observation.reward, - "done": observation.done, - } - - def get_state(self) -> Dict[str, Any]: - """Get current environment state.""" - state = self.env.state - return asdict(state) - - def _deserialize_action(self, action_data: Dict[str, Any]) -> Action: - """Convert JSON dict to Action instance.""" - metadata = action_data.pop("metadata", {}) - - # Handle tensor fields that come from JSON as lists - processed_data = {} - for key, value in action_data.items(): - if key == "tokens" and isinstance(value, (list, str)): - # Convert list or string to tensor - if isinstance(value, str): - # If it's a string, try to parse it as a list of numbers - try: - import json - value = json.loads(value) - except: - # If parsing fails, treat as empty list - value = [] - if isinstance(value, list): - import torch - processed_data[key] = torch.tensor(value, dtype=torch.long) - else: - processed_data[key] = value - elif key == "action_id" and isinstance(value, str): - # Convert action_id from string to int - try: - processed_data[key] = int(value) - except ValueError: - # If conversion fails, keep original value - processed_data[key] = value - else: - processed_data[key] = value - - action = self.action_cls(**processed_data) - action.metadata = metadata - return action - - -def create_web_interface_app( - env: Environment, - action_cls: Type[Action], - observation_cls: Type[Observation], - env_name: Optional[str] = None, -) -> FastAPI: - """ - Create a FastAPI application with web interface for the given environment. - - Args: - env: The Environment instance to serve - action_cls: The Action subclass this environment expects - observation_cls: The Observation subclass this environment returns - env_name: Optional environment name for README loading - - Returns: - FastAPI application instance with web interface - """ - from .http_server import create_fastapi_app - - # Create the base environment app - app = create_fastapi_app(env, action_cls, observation_cls) - - # Load environment metadata - metadata = load_environment_metadata(env, env_name) - - # Create web interface manager - web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata) - - # Add web interface routes - @app.get("/web", response_class=HTMLResponse) - async def web_interface(): - """Serve the web interface.""" - return get_web_interface_html(action_cls, web_manager.metadata) - - @app.get("/web/metadata") - async def web_metadata(): - """Get environment metadata.""" - return asdict(web_manager.metadata) - - @app.websocket("/ws") - async def websocket_endpoint(websocket: WebSocket): - """WebSocket endpoint for real-time updates.""" - await web_manager.connect_websocket(websocket) - try: - while True: - # Keep connection alive - await websocket.receive_text() - except WebSocketDisconnect: - await web_manager.disconnect_websocket(websocket) - - @app.post("/web/reset") - async def web_reset(): - """Reset endpoint for web interface.""" - return await web_manager.reset_environment() - - @app.post("/web/step") - async def web_step(request: Dict[str, Any]): - """Step endpoint for web interface.""" - # Check if this is a message-based request (chat environment) - if "message" in request: - message = request["message"] - # Convert message to action using the environment's message_to_action method - action = web_manager.env.message_to_action(message) - action_data = {"tokens": action.tokens.tolist()} - else: - action_data = request.get("action", {}) - - return await web_manager.step_environment(action_data) - - @app.get("/web/state") - async def web_state(): - """State endpoint for web interface.""" - return web_manager.get_state() - - return app - - -def get_web_interface_html(action_cls: Type[Action], metadata: Optional[EnvironmentMetadata] = None) -> str: - """Generate the HTML for the web interface.""" - - # Check if this is a chat environment by looking for tokens field - is_chat_env = False - if hasattr(action_cls, '__dataclass_fields__'): - for field_name, field_info in action_cls.__dataclass_fields__.items(): - if field_name == 'tokens' and hasattr(field_info.type, '__name__') and 'Tensor' in field_info.type.__name__: - is_chat_env = True - break - - # Get action fields for dynamic form generation with enhanced metadata - action_fields = _extract_action_fields(action_cls) - - return f""" - - - - - - OpenEnv Web Interface - - - -
- -
-
- - HumanAgent Interface -
-
- - {_generate_instructions_section(metadata)} - - - {_generate_action_interface(action_fields, is_chat_env)} - - -
- - -
- - -
-

Current State

-
-
- Status: - Not initialized -
-
- Episode ID: - - -
-
- Step Count: - 0 -
-
-
-
-
- - -
-
- State Observer -
-
- -
-

Current Observation

-
- No observation yet -
-
- - -
-

Action History

-
- No actions taken yet -
-
-
-
-
- - - - - """.replace('{_generate_action_form_fields(action_fields)}', _generate_action_form_fields(action_fields)) - - -def _generate_instructions_section(metadata: Optional[EnvironmentMetadata]) -> str: - """Generate the instructions section with environment documentation.""" - if not metadata or not metadata.readme_content: - return '' - - # Convert markdown to HTML (basic conversion) - import re - html_content = _markdown_to_html(metadata.readme_content) - - return f''' - -
-
-

{metadata.name}

- -
-
-
- {html_content} -
-
-
- ''' - - -def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]: - """Extract enhanced field metadata from Action class for form generation.""" - import typing - from typing import get_origin, get_args - - action_fields = [] - if not hasattr(action_cls, '__dataclass_fields__'): - return action_fields - - for field_name, field_info in action_cls.__dataclass_fields__.items(): - if field_name == 'metadata': - continue - - field_type = field_info.type - field_metadata = _extract_field_metadata(field_name, field_info) - - # Determine input type based on field type - input_type = _determine_input_type(field_type) - - # Check if field is required - is_required = field_info.default is field_info.default_factory - - action_fields.append({ - 'name': field_name, - 'type': input_type, - 'required': is_required, - 'description': field_metadata.get('description', ''), - 'default_value': field_metadata.get('default_value'), - 'choices': field_metadata.get('choices', []), - 'min_value': field_metadata.get('min_value'), - 'max_value': field_metadata.get('max_value'), - 'placeholder': field_metadata.get('placeholder', ''), - 'help_text': field_metadata.get('help_text', ''), - }) - - return action_fields - - -def _extract_field_metadata(field_name: str, field_info) -> Dict[str, Any]: - """Extract metadata from dataclass field including docstring and type hints.""" - import typing - from typing import get_origin, get_args, Literal, Union, Optional - - metadata = {} - - # Extract description from field docstring or annotation - if hasattr(field_info, 'metadata') and field_info.metadata: - # Check for custom metadata - for meta in field_info.metadata: - if isinstance(meta, dict): - metadata.update(meta) - - # Extract type information - field_type = field_info.type - origin = get_origin(field_type) - - # Handle Literal types for dropdown choices - if origin is Literal: - args = get_args(field_type) - metadata['choices'] = list(args) - - # Handle Optional types - if origin is Union: - args = get_args(field_type) - if len(args) == 2 and type(None) in args: - # This is Optional[SomeType] - non_none_type = args[0] if args[1] is type(None) else args[1] - metadata['optional'] = True - # Recursively check the non-None type for choices - if get_origin(non_none_type) is Literal: - metadata['choices'] = list(get_args(non_none_type)) - else: - # Regular Union type - metadata['choices'] = [str(arg) for arg in args if arg is not type(None)] - - # Handle numeric constraints - if field_type in (int, float): - # Check for common constraint patterns in field name - if 'count' in field_name.lower() or 'num' in field_name.lower(): - metadata['min_value'] = 0 - if 'id' in field_name.lower(): - metadata['min_value'] = 0 - - # Generate placeholder text - if 'message' in field_name.lower(): - metadata['placeholder'] = f'Enter {field_name.replace("_", " ")}...' - elif 'code' in field_name.lower(): - metadata['placeholder'] = 'Enter Python code here...' - elif 'tokens' in field_name.lower(): - metadata['placeholder'] = 'Enter comma-separated token IDs (e.g., 1,2,3,4,5)' - else: - metadata['placeholder'] = f'Enter {field_name.replace("_", " ")}...' - - # Generate help text based on field name and type - if 'action_id' in field_name.lower(): - metadata['help_text'] = 'The action ID to execute in the environment' - elif 'game_name' in field_name.lower(): - metadata['help_text'] = 'Name of the game or environment' - elif 'tokens' in field_name.lower(): - metadata['help_text'] = 'Token IDs as a comma-separated list of integers' - elif 'code' in field_name.lower(): - metadata['help_text'] = 'Python code to execute in the environment' - elif 'message' in field_name.lower(): - metadata['help_text'] = 'Text message to send' - - return metadata - - -def _determine_input_type(field_type) -> str: - """Determine the appropriate HTML input type for a field type.""" - import typing - from typing import get_origin, get_args, Literal, Union - - # Handle direct types - if field_type == str: - return "text" - elif field_type == int: - return "number" - elif field_type == float: - return "number" - elif field_type == bool: - return "checkbox" - - # Handle complex types - origin = get_origin(field_type) - - if origin is Literal: - return "select" - elif origin is Union: - args = get_args(field_type) - if len(args) == 2 and type(None) in args: - # Optional type - use the non-None type - non_none_type = args[0] if args[1] is type(None) else args[1] - return _determine_input_type(non_none_type) - elif all(isinstance(arg, str) for arg in args if arg is not type(None)): - return "select" - else: - return "text" - elif hasattr(field_type, '__name__') and 'Tensor' in field_type.__name__: - return "tensor" - else: - return "text" - - -def _markdown_to_html(markdown: str) -> str: - """Convert basic markdown to HTML for README display.""" - import html - import re - - # Escape HTML first - html_content = html.escape(markdown) - - # Convert headers - html_content = re.sub(r'^# (.*?)$', r'

\1

', html_content, flags=re.MULTILINE) - html_content = re.sub(r'^## (.*?)$', r'

\1

', html_content, flags=re.MULTILINE) - html_content = re.sub(r'^### (.*?)$', r'

\1

', html_content, flags=re.MULTILINE) - - # Convert code blocks - html_content = re.sub(r'```(.*?)\n(.*?)\n```', r'
\2
', html_content, flags=re.DOTALL) - html_content = re.sub(r'`([^`]+)`', r'\1', html_content) - - # Convert bold and italic - html_content = re.sub(r'\*\*(.*?)\*\*', r'\1', html_content) - html_content = re.sub(r'\*(.*?)\*', r'\1', html_content) - - # Convert lists - html_content = re.sub(r'^- (.*?)$', r'
  • \1
  • ', html_content, flags=re.MULTILINE) - html_content = re.sub(r'(
  • .*
  • )', r'', html_content, flags=re.DOTALL) - - # Convert line breaks - html_content = html_content.replace('\n', '
    ') - - return html_content - - -def _generate_action_interface(action_fields: List[Dict[str, Any]], is_chat_env: bool) -> str: - """Generate either a chat interface or action form based on environment type.""" - if is_chat_env: - return _generate_chat_interface() - else: - return _generate_action_form(action_fields) - -def _generate_chat_interface() -> str: - """Generate a chat-style interface for chat environments.""" - return ''' - -
    -

    Chat Interface

    -
    -
    -
    System
    -
    Chat environment ready. Send a message to start the conversation.
    -
    -
    -
    -
    - - -
    -
    - - -
    -
    -
    - ''' - -def _generate_action_form(action_fields: List[Dict[str, Any]]) -> str: - """Generate a traditional action form for non-chat environments.""" - return f''' - -
    -

    Take Action

    -
    - {_generate_action_form_fields(action_fields)} - -
    -
    - ''' - -def _generate_action_form_fields(action_fields: List[Dict[str, Any]]) -> str: - """Generate HTML form fields for action input with enhanced metadata.""" - if not action_fields: - return '

    No action fields available

    ' - - fields_html = [] - for field in action_fields: - field_html = _generate_single_field(field) - fields_html.append(field_html) - - return '\n'.join(fields_html) - - -def _generate_single_field(field: Dict[str, Any]) -> str: - """Generate HTML for a single form field with enhanced metadata.""" - field_name = field['name'] - field_type = field['type'] - required = field['required'] - placeholder = field.get('placeholder', '') - help_text = field.get('help_text', '') - choices = field.get('choices', []) - min_value = field.get('min_value') - max_value = field.get('max_value') - default_value = field.get('default_value') - - # Build label with required indicator - label_text = field_name.replace('_', ' ').title() - if required: - label_text += ' *' - - # Build input attributes - input_attrs = [] - if required: - input_attrs.append('required') - if placeholder: - input_attrs.append(f'placeholder="{placeholder}"') - if min_value is not None: - input_attrs.append(f'min="{min_value}"') - if max_value is not None: - input_attrs.append(f'max="{max_value}"') - if default_value is not None: - input_attrs.append(f'value="{default_value}"') - - attrs_str = ' '.join(input_attrs) - - if field_type == 'checkbox': - return f''' -
    - - {f'{help_text}' if help_text else ''} -
    - ''' - - elif field_type == 'select': - options_html = [] - if not required: - options_html.append(f'') - - for choice in choices: - selected = 'selected' if str(choice) == str(default_value) else '' - options_html.append(f'') - - return f''' -
    - - - {f'{help_text}' if help_text else ''} -
    - ''' - - elif field_type == 'tensor': - return f''' -
    - - - {help_text or 'Enter token IDs as comma-separated integers (e.g., 1,2,3,4,5)'} -
    - ''' - - elif field_type == 'text' and ('message' in field_name.lower() or 'code' in field_name.lower()): - return f''' -
    - - - {f'{help_text}' if help_text else ''} -
    - ''' - - else: - return f''' -
    - - - {f'{help_text}' if help_text else ''} -
    - ''' +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Web interface for OpenEnv environments. + +This module provides a web-based interface for interacting with OpenEnv environments, +including a two-pane layout for HumanAgent interaction and state observation. +""" + +from __future__ import annotations + +import json +from typing import Any, Dict, List, Optional, Type +from datetime import datetime + +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi.responses import HTMLResponse +from pydantic import BaseModel, Field, ConfigDict + +from .interfaces import Environment +from .types import Action, Observation, State, EnvironmentMetadata + + +def load_environment_metadata( + env: Environment, env_name: Optional[str] = None +) -> EnvironmentMetadata: + """ + Load environment metadata including README content. + + Args: + env: The environment instance + env_name: Optional environment name for README file lookup + + Returns: + EnvironmentMetadata with loaded information + """ + # Try to get metadata from environment if it has a method for it + if hasattr(env, "get_metadata"): + return env.get_metadata() + + # Default metadata + metadata = EnvironmentMetadata( + name=env_name or env.__class__.__name__, + description=f"{env.__class__.__name__} environment", + version="1.0.0", + ) + + # Try to load README from file system + readme_content = _load_readme_from_filesystem(env_name) + if readme_content: + metadata.readme_content = readme_content + + return metadata + + +def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]: + """ + Load README content from the filesystem. + + Tries multiple locations: + 1. Container filesystem: /app/README.md + 2. Local development: src/envs/{env_name}/README.md + 3. Environment variable: ENV_README_PATH + """ + import os + from pathlib import Path + + # Try container filesystem first + container_readme = Path("/app/README.md") + if container_readme.exists(): + try: + return container_readme.read_text(encoding="utf-8") + except Exception: + pass + + # Try environment variable path + custom_path = os.environ.get("ENV_README_PATH") + if custom_path and Path(custom_path).exists(): + try: + return Path(custom_path).read_text(encoding="utf-8") + except Exception: + pass + + # Try local development path + if env_name: + local_readme = Path(f"src/envs/{env_name}/README.md") + if local_readme.exists(): + try: + return local_readme.read_text(encoding="utf-8") + except Exception: + pass + + return None + + +class ActionLog(BaseModel): + """Log entry for an action taken.""" + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + timestamp: str = Field(description="Timestamp when action was taken") + action: Dict[str, Any] = Field(description="Action that was taken") + observation: Dict[str, Any] = Field(description="Observation returned from action") + reward: Optional[float] = Field( + default=None, description="Reward received from action" + ) + done: bool = Field(description="Whether the episode is done after this action") + step_count: int = Field(description="Step count when this action was taken") + + +class EpisodeState(BaseModel): + """Current episode state for the web interface.""" + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + episode_id: Optional[str] = Field(default=None, description="Current episode ID") + step_count: int = Field(description="Current step count in episode") + current_observation: Optional[Dict[str, Any]] = Field( + default=None, description="Current observation" + ) + action_logs: List[ActionLog] = Field( + default_factory=list, description="List of action logs" + ) + is_reset: bool = Field( + default=True, description="Whether the episode has been reset" + ) + + +class WebInterfaceManager: + """Manages the web interface for an environment.""" + + def __init__( + self, + env: Environment, + action_cls: Type[Action], + observation_cls: Type[Observation], + metadata: Optional[EnvironmentMetadata] = None, + ): + self.env = env + self.action_cls = action_cls + self.observation_cls = observation_cls + self.metadata = metadata or EnvironmentMetadata( + name=env.__class__.__name__, + description=f"{env.__class__.__name__} environment", + ) + self.episode_state = EpisodeState( + episode_id=None, step_count=0, current_observation=None, action_logs=[] + ) + self.connected_clients: List[WebSocket] = [] + + async def connect_websocket(self, websocket: WebSocket): + """Connect a new WebSocket client.""" + await websocket.accept() + self.connected_clients.append(websocket) + + # Send current state to the new client + await self._send_state_update() + + async def disconnect_websocket(self, websocket: WebSocket): + """Disconnect a WebSocket client.""" + if websocket in self.connected_clients: + self.connected_clients.remove(websocket) + + async def _send_state_update(self): + """Send current state to all connected clients.""" + if not self.connected_clients: + return + + state_data = { + "type": "state_update", + "episode_state": self.episode_state.model_dump(), + } + + # Send to all connected clients + disconnected_clients = [] + for client in self.connected_clients: + try: + await client.send_text(json.dumps(state_data)) + except Exception: + disconnected_clients.append(client) + + # Remove disconnected clients + for client in disconnected_clients: + self.connected_clients.remove(client) + + async def reset_environment(self) -> Dict[str, Any]: + """Reset the environment and update state.""" + observation: Observation = self.env.reset() + state: State = self.env.state + + # Update episode state + self.episode_state.episode_id = state.episode_id + self.episode_state.step_count = 0 + self.episode_state.current_observation = observation.model_dump( + exclude={"reward", "done", "metadata"} + ) + self.episode_state.action_logs = [] + self.episode_state.is_reset = True + + # Send state update + await self._send_state_update() + + return { + "observation": observation.model_dump( + exclude={"reward", "done", "metadata"} + ), + "reward": observation.reward, + "done": observation.done, + } + + async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]: + """Execute a step in the environment and update state.""" + # Deserialize action + action: Action = self._deserialize_action(action_data) + + # Execute step + observation: Observation = self.env.step(action) + state: State = self.env.state + + # Create action log + action_log = ActionLog( + timestamp=datetime.now().isoformat(), + action=action.model_dump(exclude={"metadata"}), + observation=observation.model_dump(exclude={"reward", "done", "metadata"}), + reward=observation.reward, + done=observation.done, + step_count=state.step_count, + ) + + # Update episode state + self.episode_state.episode_id = state.episode_id + self.episode_state.step_count = state.step_count + self.episode_state.current_observation = observation.model_dump( + exclude={"reward", "done", "metadata"} + ) + self.episode_state.action_logs.append(action_log) + self.episode_state.is_reset = False + + # Send state update + await self._send_state_update() + + return { + "observation": observation.model_dump( + exclude={"reward", "done", "metadata"} + ), + "reward": observation.reward, + "done": observation.done, + } + + def get_state(self) -> Dict[str, Any]: + """Get current environment state.""" + state: State = self.env.state + return state.model_dump() + + def _deserialize_action(self, action_data: Dict[str, Any]) -> Action: + """Convert JSON dict to Action instance using Pydantic validation.""" + # Handle tensor fields that come from JSON as lists + processed_data = {} + for key, value in action_data.items(): + if key == "tokens" and isinstance(value, (list, str)): + # Convert list or string to tensor + if isinstance(value, str): + # If it's a string, try to parse it as a list of numbers + try: + import json + + value = json.loads(value) + except Exception: + # If parsing fails, treat as empty list + value = [] + if isinstance(value, list): + import torch + + processed_data[key] = torch.tensor(value, dtype=torch.long) + else: + processed_data[key] = value + elif key == "action_id" and isinstance(value, str): + # Convert action_id from string to int + try: + processed_data[key] = int(value) + except ValueError: + # If conversion fails, keep original value + processed_data[key] = value + else: + processed_data[key] = value + + # Use Pydantic's model_validate for automatic validation + action = self.action_cls.model_validate(processed_data) + return action + + +def create_web_interface_app( + env: Environment, + action_cls: Type[Action], + observation_cls: Type[Observation], + env_name: Optional[str] = None, +) -> FastAPI: + """ + Create a FastAPI application with web interface for the given environment. + + Args: + env: The Environment instance to serve + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + env_name: Optional environment name for README loading + + Returns: + FastAPI application instance with web interface + """ + from .http_server import create_fastapi_app + + # Create the base environment app + app = create_fastapi_app(env, action_cls, observation_cls) + + # Load environment metadata + metadata = load_environment_metadata(env, env_name) + + # Create web interface manager + web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata) + + # Add web interface routes + @app.get("/web", response_class=HTMLResponse) + async def web_interface(): + """Serve the web interface.""" + return get_web_interface_html(action_cls, web_manager.metadata) + + @app.get("/web/metadata") + async def web_metadata(): + """Get environment metadata.""" + return web_manager.metadata.model_dump() + + @app.websocket("/ws") + async def websocket_endpoint(websocket: WebSocket): + """WebSocket endpoint for real-time updates.""" + await web_manager.connect_websocket(websocket) + try: + while True: + # Keep connection alive + await websocket.receive_text() + except WebSocketDisconnect: + await web_manager.disconnect_websocket(websocket) + + @app.post("/web/reset") + async def web_reset(): + """Reset endpoint for web interface.""" + return await web_manager.reset_environment() + + @app.post("/web/step") + async def web_step(request: Dict[str, Any]): + """Step endpoint for web interface.""" + # Check if this is a message-based request (chat environment) + if "message" in request: + message = request["message"] + # Convert message to action using the environment's message_to_action method + action = web_manager.env.message_to_action(message) + action_data = {"tokens": action.tokens.tolist()} + else: + action_data = request.get("action", {}) + + return await web_manager.step_environment(action_data) + + @app.get("/web/state") + async def web_state(): + """State endpoint for web interface.""" + return web_manager.get_state() + + return app + + +def get_web_interface_html( + action_cls: Type[Action], metadata: Optional[EnvironmentMetadata] = None +) -> str: + """Generate the HTML for the web interface.""" + + # Check if this is a chat environment by looking for tokens field + is_chat_env = False + if hasattr(action_cls, "model_fields"): + for field_name, field_info in action_cls.model_fields.items(): + if ( + field_name == "tokens" + and hasattr(field_info.annotation, "__name__") + and "Tensor" in field_info.annotation.__name__ + ): + is_chat_env = True + break + + # Get action fields for dynamic form generation with enhanced metadata + action_fields = _extract_action_fields(action_cls) + + return f""" + + + + + + OpenEnv Web Interface + + + +
    + +
    +
    + + HumanAgent Interface +
    +
    + + {_generate_instructions_section(metadata)} + + + {_generate_action_interface(action_fields, is_chat_env)} + + +
    + + +
    + + +
    +

    Current State

    +
    +
    + Status: + Not initialized +
    +
    + Episode ID: + - +
    +
    + Step Count: + 0 +
    +
    +
    +
    +
    + + +
    +
    + State Observer +
    +
    + +
    +

    Current Observation

    +
    + No observation yet +
    +
    + + +
    +

    Action History

    +
    + No actions taken yet +
    +
    +
    +
    +
    + + + + + """.replace( + "{_generate_action_form_fields(action_fields)}", + _generate_action_form_fields(action_fields), + ) + + +def _generate_instructions_section(metadata: Optional[EnvironmentMetadata]) -> str: + """Generate the instructions section with environment documentation.""" + if not metadata or not metadata.readme_content: + return "" + + html_content = _markdown_to_html(metadata.readme_content) + + return f""" + +
    +
    +

    {metadata.name}

    + +
    +
    +
    + {html_content} +
    +
    +
    + """ + + +def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]: + """Extract enhanced field metadata from Action class for form generation.""" + + action_fields = [] + if not hasattr(action_cls, "model_fields"): + return action_fields + + for field_name, field_info in action_cls.model_fields.items(): + if field_name == "metadata": + continue + + field_type = field_info.annotation + field_metadata = _extract_field_metadata(field_name, field_info) + + # Determine input type based on field type + input_type = _determine_input_type(field_type) + + # Check if field is required + is_required = field_info.is_required() + + action_fields.append( + { + "name": field_name, + "type": input_type, + "required": is_required, + "description": field_metadata.get("description", ""), + "default_value": field_metadata.get("default_value"), + "choices": field_metadata.get("choices", []), + "min_value": field_metadata.get("min_value"), + "max_value": field_metadata.get("max_value"), + "placeholder": field_metadata.get("placeholder", ""), + "help_text": field_metadata.get("help_text", ""), + } + ) + + return action_fields + + for field_name, field_info in action_cls.__dataclass_fields__.items(): + if field_name == "metadata": + continue + + field_type = field_info.type + field_metadata = _extract_field_metadata(field_name, field_info) + + # Determine input type based on field type + input_type = _determine_input_type(field_type) + + # Check if field is required + is_required = field_info.default is field_info.default_factory + + action_fields.append( + { + "name": field_name, + "type": input_type, + "required": is_required, + "description": field_metadata.get("description", ""), + "default_value": field_metadata.get("default_value"), + "choices": field_metadata.get("choices", []), + "min_value": field_metadata.get("min_value"), + "max_value": field_metadata.get("max_value"), + "placeholder": field_metadata.get("placeholder", ""), + "help_text": field_metadata.get("help_text", ""), + } + ) + + return action_fields + + +def _extract_field_metadata(field_name: str, field_info) -> Dict[str, Any]: + """Extract metadata from Pydantic field including description and type hints.""" + from typing import get_origin, get_args, Literal, Union + + metadata = {} + + # Extract description from Pydantic field description + if hasattr(field_info, "description") and field_info.description: + metadata["description"] = field_info.description + + # Extract default value + if hasattr(field_info, "default") and field_info.default is not None: + metadata["default_value"] = field_info.default + + # Extract type information + field_type = field_info.annotation + origin = get_origin(field_type) + + # Handle Literal types for dropdown choices + if origin is Literal: + args = get_args(field_type) + metadata["choices"] = list(args) + + # Handle Optional types + if origin is Union: + args = get_args(field_type) + if len(args) == 2 and type(None) in args: + # This is Optional[SomeType] + non_none_type = args[0] if args[1] is type(None) else args[1] + metadata["optional"] = True + # Recursively check non-None type for choices + if get_origin(non_none_type) is Literal: + metadata["choices"] = list(get_args(non_none_type)) + else: + # Regular Union type + metadata["choices"] = [str(arg) for arg in args if arg is not type(None)] + + # Handle numeric constraints from Pydantic field + if hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra: + # Extract constraints from json_schema_extra if available + schema_extra = field_info.json_schema_extra + if "ge" in schema_extra: + metadata["min_value"] = schema_extra["ge"] + if "le" in schema_extra: + metadata["max_value"] = schema_extra["le"] + + # Handle numeric constraints based on type + if field_type in (int, float): + # Check for common constraint patterns in field name + if "count" in field_name.lower() or "num" in field_name.lower(): + metadata.setdefault("min_value", 0) + if "id" in field_name.lower(): + metadata.setdefault("min_value", 0) + + # Generate placeholder text + if "message" in field_name.lower(): + metadata["placeholder"] = f"Enter {field_name.replace('_', ' ')}..." + elif "code" in field_name.lower(): + metadata["placeholder"] = "Enter Python code here..." + elif "tokens" in field_name.lower(): + metadata["placeholder"] = "Enter comma-separated token IDs (e.g., 1,2,3,4,5)" + else: + metadata["placeholder"] = f"Enter {field_name.replace('_', ' ')}..." + + # Generate help text based on field name and type + if "action_id" in field_name.lower(): + metadata["help_text"] = "The action ID to execute in environment" + elif "game_name" in field_name.lower(): + metadata["help_text"] = "Name of game or environment" + elif "tokens" in field_name.lower(): + metadata["help_text"] = "Token IDs as a comma-separated list of integers" + elif "code" in field_name.lower(): + metadata["help_text"] = "Python code to execute in environment" + elif "message" in field_name.lower(): + metadata["help_text"] = "Text message to send" + + return metadata + + +def _determine_input_type(field_type) -> str: + """Determine the appropriate HTML input type for a field type.""" + from typing import get_origin, get_args, Literal, Union + + # Handle direct types + if field_type is str: + return "text" + elif field_type is int: + return "number" + elif field_type is float: + return "number" + elif field_type is bool: + return "checkbox" + + # Handle complex types + origin = get_origin(field_type) + + if origin is Literal: + return "select" + elif origin is Union: + args = get_args(field_type) + if len(args) == 2 and type(None) in args: + # Optional type - use the non-None type + non_none_type = args[0] if args[1] is type(None) else args[1] + return _determine_input_type(non_none_type) + elif all(isinstance(arg, str) for arg in args if arg is not type(None)): + return "select" + else: + return "text" + elif hasattr(field_type, "__name__") and "Tensor" in field_type.__name__: + return "tensor" + else: + return "text" + + +def _markdown_to_html(markdown: str) -> str: + """Convert basic markdown to HTML for README display.""" + import html + import re + + # Escape HTML first + html_content = html.escape(markdown) + + # Convert headers + html_content = re.sub( + r"^# (.*?)$", r"

    \1

    ", html_content, flags=re.MULTILINE + ) + html_content = re.sub( + r"^## (.*?)$", r"

    \1

    ", html_content, flags=re.MULTILINE + ) + html_content = re.sub( + r"^### (.*?)$", r"

    \1

    ", html_content, flags=re.MULTILINE + ) + + # Convert code blocks + html_content = re.sub( + r"```(.*?)\n(.*?)\n```", + r"
    \2
    ", + html_content, + flags=re.DOTALL, + ) + html_content = re.sub(r"`([^`]+)`", r"\1", html_content) + + # Convert bold and italic + html_content = re.sub(r"\*\*(.*?)\*\*", r"\1", html_content) + html_content = re.sub(r"\*(.*?)\*", r"\1", html_content) + + # Convert lists + html_content = re.sub( + r"^- (.*?)$", r"
  • \1
  • ", html_content, flags=re.MULTILINE + ) + html_content = re.sub( + r"(
  • .*
  • )", r"", html_content, flags=re.DOTALL + ) + + # Convert line breaks + html_content = html_content.replace("\n", "
    ") + + return html_content + + +def _generate_action_interface( + action_fields: List[Dict[str, Any]], is_chat_env: bool +) -> str: + """Generate either a chat interface or action form based on environment type.""" + if is_chat_env: + return _generate_chat_interface() + else: + return _generate_action_form(action_fields) + + +def _generate_chat_interface() -> str: + """Generate a chat-style interface for chat environments.""" + return """ + +
    +

    Chat Interface

    +
    +
    +
    System
    +
    Chat environment ready. Send a message to start the conversation.
    +
    +
    +
    +
    + + +
    +
    + + +
    +
    +
    + """ + + +def _generate_action_form(action_fields: List[Dict[str, Any]]) -> str: + """Generate a traditional action form for non-chat environments.""" + return f""" + +
    +

    Take Action

    +
    + {_generate_action_form_fields(action_fields)} + +
    +
    + """ + + +def _generate_action_form_fields(action_fields: List[Dict[str, Any]]) -> str: + """Generate HTML form fields for action input with enhanced metadata.""" + if not action_fields: + return "

    No action fields available

    " + + fields_html = [] + for field in action_fields: + field_html = _generate_single_field(field) + fields_html.append(field_html) + + return "\n".join(fields_html) + + +def _generate_single_field(field: Dict[str, Any]) -> str: + """Generate HTML for a single form field with enhanced metadata.""" + field_name = field["name"] + field_type = field["type"] + required = field["required"] + placeholder = field.get("placeholder", "") + help_text = field.get("help_text", "") + choices = field.get("choices", []) + min_value = field.get("min_value") + max_value = field.get("max_value") + default_value = field.get("default_value") + + # Build label with required indicator + label_text = field_name.replace("_", " ").title() + if required: + label_text += ' *' + + # Build input attributes + input_attrs = [] + if required: + input_attrs.append("required") + if placeholder: + input_attrs.append(f'placeholder="{placeholder}"') + if min_value is not None: + input_attrs.append(f'min="{min_value}"') + if max_value is not None: + input_attrs.append(f'max="{max_value}"') + if default_value is not None: + input_attrs.append(f'value="{default_value}"') + + attrs_str = " ".join(input_attrs) + + if field_type == "checkbox": + return f''' +
    + + {f'{help_text}' if help_text else ""} +
    + ''' + + elif field_type == "select": + options_html = [] + if not required: + options_html.append(f'') + + for choice in choices: + selected = "selected" if str(choice) == str(default_value) else "" + options_html.append( + f'' + ) + + return f''' +
    + + + {f'{help_text}' if help_text else ""} +
    + ''' + + elif field_type == "tensor": + return f''' +
    + + + {help_text or "Enter token IDs as comma-separated integers (e.g., 1,2,3,4,5)"} +
    + ''' + + elif field_type == "text" and ( + "message" in field_name.lower() or "code" in field_name.lower() + ): + return f''' +
    + + + {f'{help_text}' if help_text else ""} +
    + ''' + + else: + return f''' +
    + + + {f'{help_text}' if help_text else ""} +
    + ''' From f15fbdfe4bc287d05d977079763b3795241397b1 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Mon, 17 Nov 2025 05:42:38 +0000 Subject: [PATCH 03/11] fix: specify type for state in get_state --- src/core/env_server/http_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core/env_server/http_server.py b/src/core/env_server/http_server.py index 5a0daba2..81c3bbfd 100644 --- a/src/core/env_server/http_server.py +++ b/src/core/env_server/http_server.py @@ -22,7 +22,7 @@ from fastapi import Body, FastAPI, HTTPException, status from .interfaces import Environment -from .types import Action, Observation +from .types import Action, Observation, State class HTTPEnvServer: @@ -118,7 +118,7 @@ async def step(request: Dict[str, Any]) -> Dict[str, Any]: @app.get("/state") async def get_state() -> Dict[str, Any]: """State endpoint - returns current environment state.""" - state = self.env.state + state: State = self.env.state return state.model_dump() @app.get("/health") From 522b2aef48bccc2fa2e4aaf9754845a9bf163e1b Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Mon, 17 Nov 2025 05:43:44 +0000 Subject: [PATCH 04/11] refactor: migrate echo_env to use Pydantic --- src/envs/echo_env/models.py | 81 ++++---- src/envs/echo_env/server/echo_environment.py | 204 +++++++++---------- 2 files changed, 147 insertions(+), 138 deletions(-) diff --git a/src/envs/echo_env/models.py b/src/envs/echo_env/models.py index c962629b..88f5da5e 100644 --- a/src/envs/echo_env/models.py +++ b/src/envs/echo_env/models.py @@ -1,36 +1,45 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Data models for the Echo Environment. - -The Echo environment is a simple test environment that echoes back messages. -""" - -from dataclasses import dataclass - -# Support both in-repo and standalone imports -try: - # In-repo imports (when running from OpenEnv repository) - from core.env_server.types import Action, Observation -except ImportError: - # Standalone imports (when environment is standalone with openenv-core from pip) - from openenv_core.env_server.types import Action, Observation - - -@dataclass(kw_only=True) -class EchoAction(Action): - """Action for the Echo environment - just a message to echo.""" - - message: str - - -@dataclass(kw_only=True) -class EchoObservation(Observation): - """Observation from the Echo environment - the echoed message.""" - - echoed_message: str - message_length: int = 0 \ No newline at end of file +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Data models for the Echo Environment. + +The Echo environment is a simple test environment that echoes back messages. +""" + +from pydantic import Field + +# Support both in-repo and standalone imports +try: + # In-repo imports (when running from OpenEnv repository) + from core.env_server.types import Action, Observation +except ImportError: + # Standalone imports (when environment is standalone with openenv-core from pip) + from openenv_core.env_server.types import Action, Observation + + +class EchoAction(Action): + """Action for the Echo environment - just a message to echo.""" + + message: str = Field( + ..., + min_length=1, + description="Message to echo back" + ) + + +class EchoObservation(Observation): + """Observation from the Echo environment - the echoed message.""" + + echoed_message: str = Field( + ..., + description="The echoed message from the environment" + ) + message_length: int = Field( + default=0, + ge=0, + description="Length of the echoed message" + ) \ No newline at end of file diff --git a/src/envs/echo_env/server/echo_environment.py b/src/envs/echo_env/server/echo_environment.py index 53b383af..b1eb9619 100644 --- a/src/envs/echo_env/server/echo_environment.py +++ b/src/envs/echo_env/server/echo_environment.py @@ -1,102 +1,102 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Echo Environment Implementation. - -A simple test environment that echoes back messages sent to it. -Perfect for testing HTTP server infrastructure. -""" - -from uuid import uuid4 - -# Support both in-repo and standalone imports -try: - # In-repo imports (when running from OpenEnv repository) - from core.env_server.interfaces import Environment - from core.env_server.types import State - from ..models import EchoAction, EchoObservation -except ImportError: - # Standalone imports (when environment is standalone with openenv-core from pip) - from openenv_core.env_server.interfaces import Environment - from openenv_core.env_server.types import State - from models import EchoAction, EchoObservation - - -class EchoEnvironment(Environment): - """ - A simple echo environment that echoes back messages. - - This environment is designed for testing the HTTP server infrastructure. - It maintains minimal state and simply echoes back whatever message it receives. - - Example: - >>> env = EchoEnvironment() - >>> obs = env.reset() - >>> print(obs.echoed_message) # "Echo environment ready!" - >>> - >>> obs = env.step(EchoAction(message="Hello")) - >>> print(obs.echoed_message) # "Hello" - >>> print(obs.message_length) # 5 - """ - - def __init__(self): - """Initialize the echo environment.""" - self._state = State(episode_id=str(uuid4()), step_count=0) - self._reset_count = 0 - - def reset(self) -> EchoObservation: - """ - Reset the environment. - - Returns: - EchoObservation with a ready message - """ - self._state = State(episode_id=str(uuid4()), step_count=0) - self._reset_count += 1 - - return EchoObservation( - echoed_message="Echo environment ready!", - message_length=0, - done=False, - reward=0.0, - ) - - def step(self, action: EchoAction) -> EchoObservation: # type: ignore[override] - """ - Execute a step in the environment by echoing the message. - - Args: - action: EchoAction containing the message to echo - - Returns: - EchoObservation with the echoed message and its length - """ - self._state.step_count += 1 - - message = action.message - length = len(message) - - # Simple reward: longer messages get higher rewards - reward = length * 0.1 - - return EchoObservation( - echoed_message=message, - message_length=length, - done=False, - reward=reward, - metadata={"original_message": message, "step": self._state.step_count}, - ) - - @property - def state(self) -> State: - """ - Get the current environment state. - - Returns: - Current State with episode_id and step_count - """ - return self._state +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Echo Environment Implementation. + +A simple test environment that echoes back messages sent to it. +Perfect for testing HTTP server infrastructure. +""" + +from uuid import uuid4 + +# Support both in-repo and standalone imports +try: + # In-repo imports (when running from OpenEnv repository) + from core.env_server.interfaces import Environment + from core.env_server.types import State + from ..models import EchoAction, EchoObservation +except ImportError: + # Standalone imports (when environment is standalone with openenv-core from pip) + from openenv_core.env_server.interfaces import Environment + from openenv_core.env_server.types import State + from models import EchoAction, EchoObservation + + +class EchoEnvironment(Environment): + """ + A simple echo environment that echoes back messages. + + This environment is designed for testing the HTTP server infrastructure. + It maintains minimal state and simply echoes back whatever message it receives. + + Example: + >>> env = EchoEnvironment() + >>> obs = env.reset() + >>> print(obs.echoed_message) # "Echo environment ready!" + >>> + >>> obs = env.step(EchoAction(message="Hello")) + >>> print(obs.echoed_message) # "Hello" + >>> print(obs.message_length) # 5 + """ + + def __init__(self): + """Initialize the echo environment.""" + self._state: State = State(episode_id=str(uuid4()), step_count=0) + self._reset_count: int = 0 + + def reset(self) -> EchoObservation: + """ + Reset the environment. + + Returns: + EchoObservation with a ready message + """ + self._state: State = State(episode_id=str(uuid4()), step_count=0) + self._reset_count += 1 + + return EchoObservation( + echoed_message="Echo environment ready!", + message_length=0, + done=False, + reward=0.0, + ) + + def step(self, action: EchoAction) -> EchoObservation: # type: ignore[override] + """ + Execute a step in the environment by echoing the message. + + Args: + action: EchoAction containing the message to echo + + Returns: + EchoObservation with the echoed message and its length + """ + self._state.step_count += 1 + + message: str = action.message + length: int = len(message) + + # Simple reward: longer messages get higher rewards + reward: float = length * 0.1 + + return EchoObservation( + echoed_message=message, + message_length=length, + done=False, + reward=reward, + metadata={"original_message": message, "step": self._state.step_count}, + ) + + @property + def state(self) -> State: + """ + Get the current environment state. + + Returns: + Current State with episode_id and step_count + """ + return self._state From ff1bd7c6439c9020cc5488a5fd138bdcb86ddc56 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Tue, 18 Nov 2025 06:52:30 +0000 Subject: [PATCH 05/11] feat: endpoints to retrieve JSON schemas actions, observations, and state --- src/core/env_server/http_server.py | 565 ++++++++++++++++------------- 1 file changed, 304 insertions(+), 261 deletions(-) diff --git a/src/core/env_server/http_server.py b/src/core/env_server/http_server.py index 81c3bbfd..9a4e6f6b 100644 --- a/src/core/env_server/http_server.py +++ b/src/core/env_server/http_server.py @@ -1,261 +1,304 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -HTTP server wrapper for Environment instances. - -This module provides utilities to wrap any Environment subclass and expose it -over HTTP endpoints that HTTPEnvClient can consume. -""" - -from __future__ import annotations - -import asyncio -import os -from concurrent.futures import ThreadPoolExecutor -from typing import Any, Dict, Type, Optional - -from pydantic import ValidationError -from fastapi import Body, FastAPI, HTTPException, status - -from .interfaces import Environment -from .types import Action, Observation, State - - -class HTTPEnvServer: - """ - HTTP server wrapper for Environment instances. - - This class wraps an Environment and exposes its reset(), step(), and state - methods as HTTP endpoints compatible with HTTPEnvClient. - - The server expects: - - Action deserialization: Converts JSON dict to Action subclass - - Observation serialization: Converts Observation subclass to JSON dict - - Example: - >>> from core.env_server import HTTPEnvServer - >>> from envs.coding_env.server import CodeExecutionEnvironment - >>> - >>> env = CodeExecutionEnvironment() - >>> server = HTTPEnvServer(env) - >>> - >>> # Register routes with FastAPI - >>> from fastapi import FastAPI - >>> app = FastAPI() - >>> server.register_routes(app) - """ - - def __init__( - self, - env: Environment, - action_cls: Type[Action], - observation_cls: Type[Observation], - ): - """ - Initialize HTTP server wrapper. - - Args: - env: The Environment instance to wrap - action_cls: The Action subclass this environment expects - observation_cls: The Observation subclass this environment returns - """ - self.env = env - self.action_cls = action_cls - self.observation_cls = observation_cls - # Create thread pool for running sync code in async context - # This is needed for environments using sync libraries (e.g., Playwright sync API) - self._executor = ThreadPoolExecutor(max_workers=1) - - def register_routes(self, app: Any) -> None: - """ - Register HTTP routes on a FastAPI application. - - Args: - app: FastAPI application instance - """ - - if not isinstance(app, FastAPI): - raise TypeError("app must be a FastAPI instance") - - @app.post("/reset") - async def reset(request: Dict[str, Any] = Body(default={})) -> Dict[str, Any]: - """Reset endpoint - returns initial observation.""" - # TODO: Handle seed, episode_id from request if provided - # Run sync environment code in thread pool to avoid blocking asyncio loop - loop = asyncio.get_event_loop() - observation = await loop.run_in_executor(self._executor, self.env.reset) - return self._serialize_observation(observation) - - @app.post("/step") - async def step(request: Dict[str, Any]) -> Dict[str, Any]: - """Step endpoint - executes action and returns observation.""" - # Support both {"action": {...}} and direct action fields - action_data = request.get("action", request) - # TODO: Handle timeout_s, request_id, episode_id from request if provided - - # Deserialize action with Pydantic validation - try: - action = self._deserialize_action(action_data) - except ValidationError as e: - # Return HTTP 422 with detailed validation errors - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=e.errors() - ) - - # Execute step in thread pool to avoid blocking asyncio loop - loop = asyncio.get_event_loop() - observation = await loop.run_in_executor( - self._executor, self.env.step, action - ) - - # Return serialized observation - return self._serialize_observation(observation) - - @app.get("/state") - async def get_state() -> Dict[str, Any]: - """State endpoint - returns current environment state.""" - state: State = self.env.state - return state.model_dump() - - @app.get("/health") - async def health() -> Dict[str, str]: - """Health check endpoint.""" - return {"status": "healthy"} - - def _deserialize_action(self, action_data: Dict[str, Any]) -> Action: - """ - Convert JSON dict to Action instance using Pydantic validation. - - Args: - action_data: Dictionary containing action data - - Returns: - Action instance - - Raises: - ValidationError: If action_data is invalid for the action class - - Note: - This uses Pydantic's model_validate() for automatic validation. - """ - # Pydantic handles validation automatically - action = self.action_cls.model_validate(action_data) - return action - - def _serialize_observation(self, observation: Observation) -> Dict[str, Any]: - """ - Convert Observation instance to JSON-compatible dict using Pydantic. - - Args: - observation: Observation instance - - Returns: - Dictionary compatible with HTTPEnvClient._parse_result() - - The format matches what HTTPEnvClient expects: - { - "observation": {...}, # Observation fields - "reward": float | None, - "done": bool, - } - """ - # Use Pydantic's model_dump() for serialization - obs_dict = observation.model_dump( - exclude={ - "reward", - "done", - "metadata", - } # Exclude these from observation dict - ) - - # Extract reward and done directly from the observation - reward = observation.reward - done = observation.done - - # Return in HTTPEnvClient expected format - return { - "observation": obs_dict, - "reward": reward, - "done": done, - } - - -def create_app( - env: Environment, - action_cls: Type[Action], - observation_cls: Type[Observation], - env_name: Optional[str] = None, -) -> Any: - """ - Create a FastAPI application with or without web interface. - - This function creates a FastAPI app with the web interface enabled by default, - including README integration for better user experience. - - Args: - env: The Environment instance to serve - action_cls: The Action subclass this environment expects - observation_cls: The Observation subclass this environment returns - env_name: Optional environment name for README loading - - Returns: - FastAPI application instance with or without web interface and README integration - """ - # Check if web interface should be enabled - # This can be controlled via environment variable or build argument - enable_web = os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in ( - "true", - "1", - "yes", - ) - - if enable_web: - # Import web interface only when needed - from .web_interface import create_web_interface_app - - return create_web_interface_app(env, action_cls, observation_cls, env_name) - else: - # Use standard FastAPI app without web interface - return create_fastapi_app(env, action_cls, observation_cls) - - -def create_fastapi_app( - env: Environment, - action_cls: Type[Action], - observation_cls: Type[Observation], -) -> Any: - """ - Create a FastAPI application with routes for the given environment. - - Args: - env: The Environment instance to serve - action_cls: The Action subclass this environment expects - observation_cls: The Observation subclass this environment returns - - Returns: - FastAPI application instance with routes registered - - Example: - >>> from envs.coding_env.server import CodeExecutionEnvironment - >>> from envs.coding_env.models import CodeAction, CodeObservation - >>> - >>> env = CodeExecutionEnvironment() - >>> app = create_fastapi_app(env, CodeAction, CodeObservation) - >>> - >>> # Run with: uvicorn module:app --host 0.0.0.0 --port 8000 - """ - try: - from fastapi import FastAPI - except ImportError: - raise ImportError( - "FastAPI is required. Install with: pip install fastapi uvicorn" - ) - - app = FastAPI(title="Environment HTTP Server") - server = HTTPEnvServer(env, action_cls, observation_cls) - server.register_routes(app) - return app +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +HTTP server wrapper for Environment instances. + +This module provides utilities to wrap any Environment subclass and expose it +over HTTP endpoints that HTTPEnvClient can consume. +""" + +from __future__ import annotations + +import asyncio +import os +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, Type, Optional + +from pydantic import ValidationError +from fastapi import Body, FastAPI, HTTPException, status + +from .interfaces import Environment +from .types import Action, Observation, State + + +class HTTPEnvServer: + """ + HTTP server wrapper for Environment instances. + + This class wraps an Environment and exposes its reset(), step(), and state + methods as HTTP endpoints compatible with HTTPEnvClient. + + The server expects: + - Action deserialization: Converts JSON dict to Action subclass + - Observation serialization: Converts Observation subclass to JSON dict + + Example: + >>> from core.env_server import HTTPEnvServer + >>> from envs.coding_env.server import CodeExecutionEnvironment + >>> + >>> env = CodeExecutionEnvironment() + >>> server = HTTPEnvServer(env) + >>> + >>> # Register routes with FastAPI + >>> from fastapi import FastAPI + >>> app = FastAPI() + >>> server.register_routes(app) + """ + + def __init__( + self, + env: Environment, + action_cls: Type[Action], + observation_cls: Type[Observation], + ): + """ + Initialize HTTP server wrapper. + + Args: + env: The Environment instance to wrap + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + """ + self.env = env + self.action_cls = action_cls + self.observation_cls = observation_cls + # Create thread pool for running sync code in async context + # This is needed for environments using sync libraries (e.g., Playwright sync API) + self._executor = ThreadPoolExecutor(max_workers=1) + + def register_routes(self, app: Any) -> None: + """ + Register HTTP routes on a FastAPI application. + + Args: + app: FastAPI application instance + """ + + if not isinstance(app, FastAPI): + raise TypeError("app must be a FastAPI instance") + + @app.post("/reset") + async def reset(request: Dict[str, Any] = Body(default={})) -> Dict[str, Any]: + """Reset endpoint - returns initial observation.""" + # TODO: Handle seed, episode_id from request if provided + # Run sync environment code in thread pool to avoid blocking asyncio loop + loop = asyncio.get_event_loop() + observation = await loop.run_in_executor(self._executor, self.env.reset) + return self._serialize_observation(observation) + + @app.post("/step") + async def step(request: Dict[str, Any]) -> Dict[str, Any]: + """Step endpoint - executes action and returns observation.""" + # Support both {"action": {...}} and direct action fields + action_data = request.get("action", request) + # TODO: Handle timeout_s, request_id, episode_id from request if provided + + # Deserialize action with Pydantic validation + try: + action = self._deserialize_action(action_data) + except ValidationError as e: + # Return HTTP 422 with detailed validation errors + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=e.errors() + ) + + # Execute step in thread pool to avoid blocking asyncio loop + loop = asyncio.get_event_loop() + observation = await loop.run_in_executor( + self._executor, self.env.step, action + ) + + # Return serialized observation + return self._serialize_observation(observation) + + @app.get("/state") + async def get_state() -> Dict[str, Any]: + """State endpoint - returns current environment state.""" + state: State = self.env.state + return state.model_dump() + + @app.get("/health") + async def health() -> Dict[str, str]: + """Health check endpoint.""" + return {"status": "healthy"} + + @app.get("/schema/action", tags=["Schema"]) + async def get_action_schema() -> Dict[str, Any]: + """ + Get JSON schema for actions accepted by this environment. + + Returns the complete JSON schema definition for the Action model, + including all field types, constraints, and validation rules. + This schema can be used to validate actions before sending them + to the environment, or to generate forms in web interfaces. + + Returns: + Dict containing JSON Schema + """ + return self.action_cls.model_json_schema() + + @app.get("/schema/observation", tags=["Schema"]) + async def get_observation_schema() -> Dict[str, Any]: + """ + Get JSON schema for observations returned by this environment. + + Returns the complete JSON schema definition for the Observation model, + including all field types and nested structures. This schema describes + what observations the environment will return after actions are executed. + + Returns: + Dict containing JSON Schema + """ + return self.observation_cls.model_json_schema() + + @app.get("/schema/state", tags=["Schema"]) + async def get_state_schema() -> Dict[str, Any]: + """ + Get JSON schema for environment state objects. + + Returns the complete JSON schema definition for the State model. + This schema describes the internal state representation of the + environment, which can be queried via the /state endpoint. + + Returns: + Dict containing JSON Schema + """ + return State.model_json_schema() + + def _deserialize_action(self, action_data: Dict[str, Any]) -> Action: + """ + Convert JSON dict to Action instance using Pydantic validation. + + Args: + action_data: Dictionary containing action data + + Returns: + Action instance + + Raises: + ValidationError: If action_data is invalid for the action class + + Note: + This uses Pydantic's model_validate() for automatic validation. + """ + # Pydantic handles validation automatically + action = self.action_cls.model_validate(action_data) + return action + + def _serialize_observation(self, observation: Observation) -> Dict[str, Any]: + """ + Convert Observation instance to JSON-compatible dict using Pydantic. + + Args: + observation: Observation instance + + Returns: + Dictionary compatible with HTTPEnvClient._parse_result() + + The format matches what HTTPEnvClient expects: + { + "observation": {...}, # Observation fields + "reward": float | None, + "done": bool, + } + """ + # Use Pydantic's model_dump() for serialization + obs_dict = observation.model_dump( + exclude={ + "reward", + "done", + "metadata", + } # Exclude these from observation dict + ) + + # Extract reward and done directly from the observation + reward = observation.reward + done = observation.done + + # Return in HTTPEnvClient expected format + return { + "observation": obs_dict, + "reward": reward, + "done": done, + } + + +def create_app( + env: Environment, + action_cls: Type[Action], + observation_cls: Type[Observation], + env_name: Optional[str] = None, +) -> Any: + """ + Create a FastAPI application with or without web interface. + + This function creates a FastAPI app with the web interface enabled by default, + including README integration for better user experience. + + Args: + env: The Environment instance to serve + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + env_name: Optional environment name for README loading + + Returns: + FastAPI application instance with or without web interface and README integration + """ + # Check if web interface should be enabled + # This can be controlled via environment variable or build argument + enable_web = os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in ( + "true", + "1", + "yes", + ) + + if enable_web: + # Import web interface only when needed + from .web_interface import create_web_interface_app + + return create_web_interface_app(env, action_cls, observation_cls, env_name) + else: + # Use standard FastAPI app without web interface + return create_fastapi_app(env, action_cls, observation_cls) + + +def create_fastapi_app( + env: Environment, + action_cls: Type[Action], + observation_cls: Type[Observation], +) -> Any: + """ + Create a FastAPI application with routes for the given environment. + + Args: + env: The Environment instance to serve + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + + Returns: + FastAPI application instance with routes registered + + Example: + >>> from envs.coding_env.server import CodeExecutionEnvironment + >>> from envs.coding_env.models import CodeAction, CodeObservation + >>> + >>> env = CodeExecutionEnvironment() + >>> app = create_fastapi_app(env, CodeAction, CodeObservation) + >>> + >>> # Run with: uvicorn module:app --host 0.0.0.0 --port 8000 + """ + try: + from fastapi import FastAPI + except ImportError: + raise ImportError( + "FastAPI is required. Install with: pip install fastapi uvicorn" + ) + + app = FastAPI(title="Environment HTTP Server") + server = HTTPEnvServer(env, action_cls, observation_cls) + server.register_routes(app) + return app From 82acaf28194cdf7580e71dc3fd44050731a1f8ef Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Tue, 18 Nov 2025 06:56:39 +0000 Subject: [PATCH 06/11] feat: request and response models for reset and step endpoints --- src/core/env_server/http_server.py | 88 ++++++++--- src/core/env_server/interfaces.py | 246 +++++++++++++++-------------- src/core/env_server/types.py | 67 ++++++++ 3 files changed, 258 insertions(+), 143 deletions(-) diff --git a/src/core/env_server/http_server.py b/src/core/env_server/http_server.py index 9a4e6f6b..9d1fec9b 100644 --- a/src/core/env_server/http_server.py +++ b/src/core/env_server/http_server.py @@ -14,15 +14,24 @@ from __future__ import annotations import asyncio +import inspect import os from concurrent.futures import ThreadPoolExecutor -from typing import Any, Dict, Type, Optional +from typing import Any, Dict, Optional, Type -from pydantic import ValidationError from fastapi import Body, FastAPI, HTTPException, status +from pydantic import ValidationError from .interfaces import Environment -from .types import Action, Observation, State +from .types import ( + Action, + Observation, + ResetRequest, + ResetResponse, + State, + StepRequest, + StepResponse, +) class HTTPEnvServer: @@ -81,21 +90,37 @@ def register_routes(self, app: Any) -> None: if not isinstance(app, FastAPI): raise TypeError("app must be a FastAPI instance") - @app.post("/reset") - async def reset(request: Dict[str, Any] = Body(default={})) -> Dict[str, Any]: + @app.post("/reset", response_model=ResetResponse) + async def reset( + request: ResetRequest = Body(default_factory=ResetRequest), + ) -> ResetResponse: """Reset endpoint - returns initial observation.""" - # TODO: Handle seed, episode_id from request if provided - # Run sync environment code in thread pool to avoid blocking asyncio loop - loop = asyncio.get_event_loop() - observation = await loop.run_in_executor(self._executor, self.env.reset) - return self._serialize_observation(observation) - - @app.post("/step") - async def step(request: Dict[str, Any]) -> Dict[str, Any]: + # Handle optional parameters + kwargs = {} + if request.seed is not None: + kwargs["seed"] = request.seed + if request.episode_id is not None: + kwargs["episode_id"] = request.episode_id + + # Pass arguments only if environment accepts them + sig = inspect.signature(self.env.reset) + valid_kwargs = {} + + has_kwargs = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() + ) + + for k, v in kwargs.items(): + if k in sig.parameters or has_kwargs: + valid_kwargs[k] = v + + observation = self.env.reset(**valid_kwargs) + return ResetResponse(**self._serialize_observation(observation)) + + @app.post("/step", response_model=StepResponse) + async def step(request: StepRequest) -> StepResponse: """Step endpoint - executes action and returns observation.""" - # Support both {"action": {...}} and direct action fields - action_data = request.get("action", request) - # TODO: Handle timeout_s, request_id, episode_id from request if provided + action_data = request.action # Deserialize action with Pydantic validation try: @@ -106,20 +131,33 @@ async def step(request: Dict[str, Any]) -> Dict[str, Any]: status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=e.errors() ) - # Execute step in thread pool to avoid blocking asyncio loop - loop = asyncio.get_event_loop() - observation = await loop.run_in_executor( - self._executor, self.env.step, action + # Handle optional parameters + kwargs = {} + if request.timeout_s is not None: + kwargs["timeout_s"] = request.timeout_s + + # Pass arguments only if environment accepts them + sig = inspect.signature(self.env.step) + valid_kwargs = {} + + has_kwargs = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() ) + for k, v in kwargs.items(): + if k in sig.parameters or has_kwargs: + valid_kwargs[k] = v + + # Execute step + observation = self.env.step(action, **valid_kwargs) + # Return serialized observation - return self._serialize_observation(observation) + return StepResponse(**self._serialize_observation(observation)) - @app.get("/state") - async def get_state() -> Dict[str, Any]: + @app.get("/state", response_model=State) + async def get_state() -> State: """State endpoint - returns current environment state.""" - state: State = self.env.state - return state.model_dump() + return self.env.state @app.get("/health") async def health() -> Dict[str, str]: diff --git a/src/core/env_server/interfaces.py b/src/core/env_server/interfaces.py index caa2d76d..afcbdde9 100644 --- a/src/core/env_server/interfaces.py +++ b/src/core/env_server/interfaces.py @@ -1,118 +1,128 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from abc import ABC, abstractmethod -from typing import Any, Protocol, TypedDict - -from .types import Action, Observation, State - - -class Message(TypedDict): - """A message in a conversation. - - Compatible with Huggingface chat template format. - """ - - role: str - content: str - - -class ModelTokenizer(Protocol): - """Protocol for tokenizers that support chat templates. - - This protocol defines the interface that tokenizers must implement - to work with chat-based environments. It's compatible with - Huggingface transformers tokenizers. - """ - - def apply_chat_template( - self, - conversation: list[Message], - tokenize: bool = True, - return_tensors: str | None = None, - **kwargs: Any, - ) -> Any: - """Apply a chat template to format and optionally tokenize a conversation. - - Args: - conversation: List of message dictionaries with 'role' and 'content' - tokenize: Whether to tokenize the output - return_tensors: Format for returned tensors ('pt' for PyTorch) - **kwargs: Additional arguments - - Returns: - Formatted and optionally tokenized conversation - """ - ... - - def decode( - self, token_ids: Any, skip_special_tokens: bool = False, **kwargs: Any - ) -> str: - """Decode token IDs back to text. - - Args: - token_ids: Token IDs to decode - skip_special_tokens: Whether to skip special tokens in output - **kwargs: Additional arguments - - Returns: - Decoded text string - """ - ... - - -class Transform(ABC): - """Transform observations to add rewards, metrics, or other modifications. - - Transforms follow the TorchRL pattern where they take an observation - and return a (potentially modified) observation. This allows for - flexible reward computation and observation augmentation. - """ - - @abstractmethod - def __call__(self, observation: Observation) -> Observation: - """Transform an observation. - - Args: - observation: The input observation - - Returns: - The transformed observation - """ - pass - - -class Environment(ABC): - """Base class for all environment servers following Gym/Gymnasium API. - - Args: - transform: Optional transform to apply to observations - """ - - def __init__(self, transform: Transform | None = None): - self.transform = transform - - @abstractmethod - def reset(self) -> Observation: - """Reset the environment and return initial observation.""" - pass - - @abstractmethod - def step(self, action: Action) -> Observation: - """Take a step in the environment.""" - pass - - @property - @abstractmethod - def state(self) -> State: - """Get the current environment state.""" - pass - - def _apply_transform(self, observation: Observation) -> Observation: - """Apply transform if one is provided.""" - if self.transform is not None: - return self.transform(observation) - return observation +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +from typing import Any, Optional, Protocol, TypedDict + +from .types import Action, Observation, State + + +class Message(TypedDict): + """A message in a conversation. + + Compatible with Huggingface chat template format. + """ + + role: str + content: str + + +class ModelTokenizer(Protocol): + """Protocol for tokenizers that support chat templates. + + This protocol defines the interface that tokenizers must implement + to work with chat-based environments. It's compatible with + Huggingface transformers tokenizers. + """ + + def apply_chat_template( + self, + conversation: list[Message], + tokenize: bool = True, + return_tensors: str | None = None, + **kwargs: Any, + ) -> Any: + """Apply a chat template to format and optionally tokenize a conversation. + + Args: + conversation: List of message dictionaries with 'role' and 'content' + tokenize: Whether to tokenize the output + return_tensors: Format for returned tensors ('pt' for PyTorch) + **kwargs: Additional arguments + + Returns: + Formatted and optionally tokenized conversation + """ + ... + + def decode( + self, token_ids: Any, skip_special_tokens: bool = False, **kwargs: Any + ) -> str: + """Decode token IDs back to text. + + Args: + token_ids: Token IDs to decode + skip_special_tokens: Whether to skip special tokens in output + **kwargs: Additional arguments + + Returns: + Decoded text string + """ + ... + + +class Transform(ABC): + """Transform observations to add rewards, metrics, or other modifications. + + Transforms follow the TorchRL pattern where they take an observation + and return a (potentially modified) observation. This allows for + flexible reward computation and observation augmentation. + """ + + @abstractmethod + def __call__(self, observation: Observation) -> Observation: + """Transform an observation. + + Args: + observation: The input observation + + Returns: + The transformed observation + """ + pass + + +class Environment(ABC): + """Base class for all environment servers following Gym/Gymnasium API. + + Args: + transform: Optional transform to apply to observations + """ + + def __init__(self, transform: Transform | None = None): + self.transform = transform + + @abstractmethod + def reset( + self, + seed: Optional[int] = None, + episode_id: Optional[str] = None, + **kwargs: Any, + ) -> Observation: + """Reset the environment and return initial observation.""" + pass + + @abstractmethod + def step( + self, + action: Action, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> Observation: + """Take a step in the environment.""" + pass + + @property + @abstractmethod + def state(self) -> State: + """Get the current environment state.""" + pass + + def _apply_transform(self, observation: Observation) -> Observation: + """Apply transform if one is provided.""" + if self.transform is not None: + return self.transform(observation) + return observation diff --git a/src/core/env_server/types.py b/src/core/env_server/types.py index 2a3256d5..0cde1197 100644 --- a/src/core/env_server/types.py +++ b/src/core/env_server/types.py @@ -52,6 +52,73 @@ class Observation(BaseModel): ) +class ResetRequest(BaseModel): + """Request model for environment reset.""" + + model_config = ConfigDict( + extra="forbid", + json_schema_extra={"examples": [{"seed": 42, "episode_id": "episode-001"}, {}]}, + ) + + seed: Optional[int] = Field( + default=None, ge=0, description="Random seed for reproducible episodes" + ) + episode_id: Optional[str] = Field( + default=None, max_length=255, description="Custom episode identifier" + ) + + +class ResetResponse(BaseModel): + """Response model for environment reset.""" + + model_config = ConfigDict(extra="forbid") + + observation: Dict[str, Any] = Field( + ..., description="Initial observation from the environment" + ) + reward: Optional[float] = Field( + default=None, description="Initial reward (typically None at reset)" + ) + done: bool = Field( + default=False, description="Whether episode is already done (typically False)" + ) + + +class StepRequest(BaseModel): + """Request model for environment step.""" + + model_config = ConfigDict(extra="forbid") + + action: Dict[str, Any] = Field( + ..., + description="Action to execute, must conform to environment's action schema", + ) + timeout_s: Optional[float] = Field( + default=None, + gt=0, + description="Optional timeout in seconds for action execution", + ) + request_id: Optional[str] = Field( + default=None, + max_length=255, + description="Optional request identifier for tracking", + ) + + +class StepResponse(BaseModel): + """Response model for environment step.""" + + model_config = ConfigDict(extra="forbid") + + observation: Dict[str, Any] = Field( + ..., description="Observation resulting from the action" + ) + reward: Optional[float] = Field( + default=None, description="Reward signal from the action" + ) + done: bool = Field(default=False, description="Whether the episode has terminated") + + class State(BaseModel): """Base class for environment state. From 04eb97b2bc513cbc8147a6fcc538525913f5bc10 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Tue, 18 Nov 2025 09:05:57 +0000 Subject: [PATCH 07/11] feat: extra fields in reset and step request models for custom params --- src/core/env_server/http_server.py | 12 +- src/core/env_server/types.py | 12 +- src/core/http_env_client.py | 439 ++++++++++++++++------------- 3 files changed, 250 insertions(+), 213 deletions(-) diff --git a/src/core/env_server/http_server.py b/src/core/env_server/http_server.py index 9d1fec9b..204aee74 100644 --- a/src/core/env_server/http_server.py +++ b/src/core/env_server/http_server.py @@ -96,11 +96,8 @@ async def reset( ) -> ResetResponse: """Reset endpoint - returns initial observation.""" # Handle optional parameters - kwargs = {} - if request.seed is not None: - kwargs["seed"] = request.seed - if request.episode_id is not None: - kwargs["episode_id"] = request.episode_id + # Start with all fields from the request, including extra ones + kwargs = request.model_dump(exclude_unset=True) # Pass arguments only if environment accepts them sig = inspect.signature(self.env.reset) @@ -132,9 +129,8 @@ async def step(request: StepRequest) -> StepResponse: ) # Handle optional parameters - kwargs = {} - if request.timeout_s is not None: - kwargs["timeout_s"] = request.timeout_s + # Start with all fields from the request, including extra ones, but exclude 'action' + kwargs = request.model_dump(exclude_unset=True, exclude={'action'}) # Pass arguments only if environment accepts them sig = inspect.signature(self.env.step) diff --git a/src/core/env_server/types.py b/src/core/env_server/types.py index 0cde1197..d96d7baf 100644 --- a/src/core/env_server/types.py +++ b/src/core/env_server/types.py @@ -56,7 +56,7 @@ class ResetRequest(BaseModel): """Request model for environment reset.""" model_config = ConfigDict( - extra="forbid", + extra="allow", # Allow extra fields for custom reset parameters json_schema_extra={"examples": [{"seed": 42, "episode_id": "episode-001"}, {}]}, ) @@ -87,7 +87,15 @@ class ResetResponse(BaseModel): class StepRequest(BaseModel): """Request model for environment step.""" - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict( + extra="allow", # Allow extra fields for custom step parameters + json_schema_extra={ + "examples": [ + {"action": {"value": 1}, "timeout_s": 30.0}, + {"action": {"value": 1}, "render": True, "verbose": False}, + ] + }, + ) action: Dict[str, Any] = Field( ..., diff --git a/src/core/http_env_client.py b/src/core/http_env_client.py index 16bbfa5d..007ef6a5 100644 --- a/src/core/http_env_client.py +++ b/src/core/http_env_client.py @@ -1,203 +1,236 @@ -""" -core/runner_env.py -Minimal HTTP-based environment client. -- Talks to a single env worker exposing: POST /reset, POST /step - -Future hooks (commented below) for: -- episode_id, seed on reset -- request_id on step -- custom headers (auth/trace) -""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar - -import requests - -from .client_types import StepResult -from .containers.runtime import LocalDockerProvider - -if TYPE_CHECKING: - from .containers.runtime import ContainerProvider - -ActT = TypeVar("ActT") -ObsT = TypeVar("ObsT") -EnvClientT = TypeVar("EnvClientT", bound="HTTPEnvClient") - - -class HTTPEnvClient(ABC, Generic[ActT, ObsT]): - def __init__( - self, - base_url: str, - request_timeout_s: float = 15.0, - default_headers: Optional[Dict[str, str]] = None, - provider: Optional["ContainerProvider"] = None, - ): - self._base = base_url.rstrip("/") - self._timeout = float(request_timeout_s) - self._http = requests.Session() - self._headers = default_headers or {} - self._provider = provider - - @classmethod - def from_docker_image( - cls: Type[EnvClientT], - image: str, - provider: Optional["ContainerProvider"] = None, - **kwargs: Any, - ) -> EnvClientT: - """ - Create an environment client by spinning up a Docker container locally. - - This is a development utility that: - 1. Starts a Docker container from the specified image - 2. Waits for the server to be ready - 3. Creates and returns a client instance connected to the container - - Note: The container lifecycle management is left to the user or higher-level - orchestration. The container will keep running until manually stopped. - - Args: - image: Docker image name to run (e.g., "echo-env:latest") - provider: Container provider to use (defaults to LocalDockerProvider) - **kwargs: Additional arguments to pass to provider.start_container() - (e.g., env_vars, port) - - Returns: - An instance of the client class connected to the running container - - Example: - >>> from envs.coding_env.client import CodingEnv - >>> from envs.coding_env.models import CodeAction - >>> - >>> # Create environment from image - >>> env = CodingEnv.from_docker_image("coding-env:latest") - >>> - >>> # Create environment with custom env vars - >>> env = CodingEnv.from_docker_image( - ... "coding-env:latest", - ... env_vars={"MY_VAR": "value"} - ... ) - >>> - >>> # Use the environment - >>> result = env.reset() - >>> print(result.observation) - >>> - >>> step_result = env.step(CodeAction(code="print('hello')")) - >>> print(step_result.observation.stdout) - >>> - >>> # Cleanup (optional) - >>> env.close() - """ - - # Use default provider if none provided - if provider is None: - provider = LocalDockerProvider() - - # 1. Start container with optional kwargs (e.g., env_vars, port) - base_url = provider.start_container(image, **kwargs) - - # 2. Wait for server to be ready - provider.wait_for_ready(base_url) - - # 3. Create and return client instance with provider reference - return cls(base_url=base_url, provider=provider) - - @classmethod - def from_hub(cls: Type[EnvClientT], repo_id: str, provider: Optional["ContainerProvider"] = None, **kwargs: Any) -> EnvClientT: - """ - Create an environment client by pulling from a Hugging Face model hub. - """ - - if provider is None: - provider = LocalDockerProvider() - - if "tag" in kwargs: - tag = kwargs["tag"] - else: - tag = "latest" - - base_url = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}" - - return cls.from_docker_image(image=base_url, provider=provider) - - @abstractmethod - def _step_payload(self, action: ActT) -> dict: - """Convert an Action object to the JSON body expected by the env server.""" - raise NotImplementedError - - @abstractmethod - def _parse_result(self, payload: dict) -> StepResult[ObsT]: - """Convert a JSON response from the env server to StepResult[ObsT].""" - raise NotImplementedError - - @abstractmethod - def _parse_state(self, payload: dict) -> Any: - """Convert a JSON response from the state endpoint to a State object.""" - raise NotImplementedError - - # ---------- Environment Server Interface Methods ---------- - def reset(self) -> StepResult[ObsT]: - body: Dict[str, Any] = {} - # TODO: later: - # body["seed"] = seed - # body["episode_id"] = episode_id - r = self._http.post( - f"{self._base}/reset", - json=body, - headers=self._headers, - timeout=self._timeout, - ) - r.raise_for_status() - return self._parse_result(r.json()) - - def step(self, action: ActT) -> StepResult[ObsT]: - body: Dict[str, Any] = { - "action": self._step_payload(action), - "timeout_s": int(self._timeout), - } - # TODO: later: - # body["request_id"] = str(uuid.uuid4()) - # body["episode_id"] = current_episode_id - r = self._http.post( - f"{self._base}/step", - json=body, - headers=self._headers, - timeout=self._timeout, - ) - r.raise_for_status() - return self._parse_result(r.json()) - - def state(self) -> Any: - """ - Get the current environment state from the server. - - Returns: - State object with environment state information (e.g., episode_id, step_count) - - Example: - >>> client = EchoEnv.from_docker_image("echo-env:latest") - >>> result = client.reset() - >>> state = client.state() - >>> print(state.episode_id) - >>> print(state.step_count) - """ - r = self._http.get( - f"{self._base}/state", - headers=self._headers, - timeout=self._timeout, - ) - r.raise_for_status() - return self._parse_state(r.json()) - - def close(self) -> None: - """ - Close the environment and clean up resources. - - If this client was created via from_docker_image(), this will stop - and remove the associated container. - """ - if self._provider is not None: - self._provider.stop_container() +""" +core/runner_env.py +Minimal HTTP-based environment client. +- Talks to a single env worker exposing: POST /reset, POST /step + +Future hooks (commented below) for: +- episode_id, seed on reset +- request_id on step +- custom headers (auth/trace) +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar + +import requests + +from .client_types import StepResult +from .containers.runtime import LocalDockerProvider + +if TYPE_CHECKING: + from .containers.runtime import ContainerProvider + +ActT = TypeVar("ActT") +ObsT = TypeVar("ObsT") +EnvClientT = TypeVar("EnvClientT", bound="HTTPEnvClient") + + +class HTTPEnvClient(ABC, Generic[ActT, ObsT]): + def __init__( + self, + base_url: str, + request_timeout_s: float = 15.0, + default_headers: Optional[Dict[str, str]] = None, + provider: Optional["ContainerProvider"] = None, + ): + self._base = base_url.rstrip("/") + self._timeout = float(request_timeout_s) + self._http = requests.Session() + self._headers = default_headers or {} + self._provider = provider + + @classmethod + def from_docker_image( + cls: Type[EnvClientT], + image: str, + provider: Optional["ContainerProvider"] = None, + **kwargs: Any, + ) -> EnvClientT: + """ + Create an environment client by spinning up a Docker container locally. + + This is a development utility that: + 1. Starts a Docker container from the specified image + 2. Waits for the server to be ready + 3. Creates and returns a client instance connected to the container + + Note: The container lifecycle management is left to the user or higher-level + orchestration. The container will keep running until manually stopped. + + Args: + image: Docker image name to run (e.g., "echo-env:latest") + provider: Container provider to use (defaults to LocalDockerProvider) + **kwargs: Additional arguments to pass to provider.start_container() + (e.g., env_vars, port) + + Returns: + An instance of the client class connected to the running container + + Example: + >>> from envs.coding_env.client import CodingEnv + >>> from envs.coding_env.models import CodeAction + >>> + >>> # Create environment from image + >>> env = CodingEnv.from_docker_image("coding-env:latest") + >>> + >>> # Create environment with custom env vars + >>> env = CodingEnv.from_docker_image( + ... "coding-env:latest", + ... env_vars={"MY_VAR": "value"} + ... ) + >>> + >>> # Use the environment + >>> result = env.reset() + >>> print(result.observation) + >>> + >>> step_result = env.step(CodeAction(code="print('hello')")) + >>> print(step_result.observation.stdout) + >>> + >>> # Cleanup (optional) + >>> env.close() + """ + + # Use default provider if none provided + if provider is None: + provider = LocalDockerProvider() + + # 1. Start container with optional kwargs (e.g., env_vars, port) + base_url = provider.start_container(image, **kwargs) + + # 2. Wait for server to be ready + provider.wait_for_ready(base_url) + + # 3. Create and return client instance with provider reference + return cls(base_url=base_url, provider=provider) + + @classmethod + def from_hub( + cls: Type[EnvClientT], + repo_id: str, + provider: Optional["ContainerProvider"] = None, + **kwargs: Any, + ) -> EnvClientT: + """ + Create an environment client by pulling from a Hugging Face model hub. + """ + + if provider is None: + provider = LocalDockerProvider() + + if "tag" in kwargs: + tag = kwargs["tag"] + else: + tag = "latest" + + base_url = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}" + + return cls.from_docker_image(image=base_url, provider=provider) + + @abstractmethod + def _step_payload(self, action: ActT) -> dict: + """Convert an Action object to the JSON body expected by the env server.""" + raise NotImplementedError + + @abstractmethod + def _parse_result(self, payload: dict) -> StepResult[ObsT]: + """Convert a JSON response from the env server to StepResult[ObsT].""" + raise NotImplementedError + + @abstractmethod + def _parse_state(self, payload: dict) -> Any: + """Convert a JSON response from the state endpoint to a State object.""" + raise NotImplementedError + + # ---------- Environment Server Interface Methods ---------- + def reset(self, **kwargs: Any) -> StepResult[ObsT]: + """ + Reset the environment with optional parameters. + + Args: + **kwargs: Optional parameters passed to the environment's reset method. + Common parameters include: + - seed: Random seed for reproducibility + - episode_id: Custom episode identifier + - Any environment-specific reset parameters + + Returns: + StepResult containing initial observation + + Example: + >>> env.reset(seed=42, episode_id="ep-001") + """ + body: Dict[str, Any] = kwargs.copy() + r = self._http.post( + f"{self._base}/reset", + json=body, + headers=self._headers, + timeout=self._timeout, + ) + r.raise_for_status() + return self._parse_result(r.json()) + + def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: + """ + Execute an action in the environment with optional parameters. + + Args: + action: The action to execute + **kwargs: Optional parameters passed to the environment's step method. + Common parameters include: + - timeout_s: Execution timeout in seconds + - request_id: Request identifier for tracking + - render: Whether to render the environment + - Any environment-specific step parameters + + Returns: + StepResult containing observation, reward, and done status + + Example: + >>> env.step(action, timeout_s=30.0, request_id="req-123", render=True) + """ + body: Dict[str, Any] = { + "action": self._step_payload(action), + **kwargs # Forward all additional parameters + } + r = self._http.post( + f"{self._base}/step", + json=body, + headers=self._headers, + timeout=self._timeout, + ) + r.raise_for_status() + return self._parse_result(r.json()) + + def state(self) -> Any: + """ + Get the current environment state from the server. + + Returns: + State object with environment state information (e.g., episode_id, step_count) + + Example: + >>> client = EchoEnv.from_docker_image("echo-env:latest") + >>> result = client.reset() + >>> state = client.state() + >>> print(state.episode_id) + >>> print(state.step_count) + """ + r = self._http.get( + f"{self._base}/state", + headers=self._headers, + timeout=self._timeout, + ) + r.raise_for_status() + return self._parse_state(r.json()) + + def close(self) -> None: + """ + Close the environment and clean up resources. + + If this client was created via from_docker_image(), this will stop + and remove the associated container. + """ + if self._provider is not None: + self._provider.stop_container() From 4078161255593b571448e9dbca2369c077dda5ff Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Wed, 19 Nov 2025 15:05:40 +0530 Subject: [PATCH 08/11] chore: API docs and metadata extraction for action fields --- src/core/env_server/http_server.py | 217 +++++++++++++++++++++---- src/core/env_server/web_interface.py | 226 ++++++++++----------------- 2 files changed, 269 insertions(+), 174 deletions(-) diff --git a/src/core/env_server/http_server.py b/src/core/env_server/http_server.py index 204aee74..6f3046cb 100644 --- a/src/core/env_server/http_server.py +++ b/src/core/env_server/http_server.py @@ -90,7 +90,31 @@ def register_routes(self, app: Any) -> None: if not isinstance(app, FastAPI): raise TypeError("app must be a FastAPI instance") - @app.post("/reset", response_model=ResetResponse) + @app.post( + "/reset", + response_model=ResetResponse, + tags=["Environment Control"], + summary="Reset the environment", + description=""" +Reset the environment to its initial state and return the first observation. + +You can optionally provide a seed for reproducibility and an episode_id for tracking. + """, + responses={ + 200: { + "description": "Environment reset successfully", + "content": { + "application/json": { + "example": { + "observation": {"status": "ready", "data": {}}, + "reward": None, + "done": False, + } + } + }, + } + }, + ) async def reset( request: ResetRequest = Body(default_factory=ResetRequest), ) -> ResetResponse: @@ -114,7 +138,56 @@ async def reset( observation = self.env.reset(**valid_kwargs) return ResetResponse(**self._serialize_observation(observation)) - @app.post("/step", response_model=StepResponse) + @app.post( + "/step", + response_model=StepResponse, + tags=["Environment Control"], + summary="Execute an action in the environment", + description=""" +Execute an action in the environment and receive the resulting observation. + +The action must conform to the environment's action schema, which can be +retrieved from the `/schema/action` endpoint. If the action is invalid, +the endpoint will return HTTP 422 with detailed validation errors. + +The response includes: +- **observation**: The environment's response to the action +- **reward**: Optional reward signal (float or None) +- **done**: Boolean indicating if the episode has terminated + """, + responses={ + 200: { + "description": "Action executed successfully", + "content": { + "application/json": { + "example": { + "observation": {"status": "success", "data": {}}, + "reward": 1.0, + "done": False, + } + } + }, + }, + 422: { + "description": "Validation error - invalid action format or values", + "content": { + "application/json": { + "example": { + "detail": [ + { + "type": "string_too_short", + "loc": ["body", "action", "message"], + "msg": "String should have at least 1 character", + "input": "", + } + ] + } + } + }, + }, + 500: {"description": "Internal server error during action execution"}, + }, + ) async def step(request: StepRequest) -> StepResponse: """Step endpoint - executes action and returns observation.""" action_data = request.action @@ -130,7 +203,7 @@ async def step(request: StepRequest) -> StepResponse: # Handle optional parameters # Start with all fields from the request, including extra ones, but exclude 'action' - kwargs = request.model_dump(exclude_unset=True, exclude={'action'}) + kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) # Pass arguments only if environment accepts them sig = inspect.signature(self.env.step) @@ -150,17 +223,45 @@ async def step(request: StepRequest) -> StepResponse: # Return serialized observation return StepResponse(**self._serialize_observation(observation)) - @app.get("/state", response_model=State) + @app.get( + "/state", + response_model=State, + tags=["State Management"], + summary="Get current environment state", + description=""" +Retrieve the current internal state of the environment. + +This endpoint allows inspection of the environment state without modifying it. +The structure of the state object is defined by the environment's State model. + """, + ) async def get_state() -> State: """State endpoint - returns current environment state.""" return self.env.state - @app.get("/health") + @app.get( + "/health", + tags=["Health"], + summary="Health check", + description="Check if the environment server is running and healthy.", + ) async def health() -> Dict[str, str]: """Health check endpoint.""" return {"status": "healthy"} - @app.get("/schema/action", tags=["Schema"]) + @app.get( + "/schema/action", + tags=["Schema"], + summary="Get action JSON schema", + description=""" +Get JSON schema for actions accepted by this environment. + +Returns the complete JSON schema definition for the Action model, +including all field types, constraints, and validation rules. +This schema can be used to validate actions before sending them +to the environment, or to generate forms in web interfaces. + """, + ) async def get_action_schema() -> Dict[str, Any]: """ Get JSON schema for actions accepted by this environment. @@ -175,7 +276,18 @@ async def get_action_schema() -> Dict[str, Any]: """ return self.action_cls.model_json_schema() - @app.get("/schema/observation", tags=["Schema"]) + @app.get( + "/schema/observation", + tags=["Schema"], + summary="Get observation JSON schema", + description=""" +Get JSON schema for observations returned by this environment. + +Returns the complete JSON schema definition for the Observation model, +including all field types and nested structures. This schema describes +what observations the environment will return after actions are executed. + """, + ) async def get_observation_schema() -> Dict[str, Any]: """ Get JSON schema for observations returned by this environment. @@ -189,7 +301,18 @@ async def get_observation_schema() -> Dict[str, Any]: """ return self.observation_cls.model_json_schema() - @app.get("/schema/state", tags=["Schema"]) + @app.get( + "/schema/state", + tags=["Schema"], + summary="Get state JSON schema", + description=""" +Get JSON schema for environment state objects. + +Returns the complete JSON schema definition for the State model. +This schema describes the internal state representation of the +environment, which can be queried via the /state endpoint. + """, + ) async def get_state_schema() -> Dict[str, Any]: """ Get JSON schema for environment state objects. @@ -305,26 +428,7 @@ def create_fastapi_app( action_cls: Type[Action], observation_cls: Type[Observation], ) -> Any: - """ - Create a FastAPI application with routes for the given environment. - - Args: - env: The Environment instance to serve - action_cls: The Action subclass this environment expects - observation_cls: The Observation subclass this environment returns - - Returns: - FastAPI application instance with routes registered - - Example: - >>> from envs.coding_env.server import CodeExecutionEnvironment - >>> from envs.coding_env.models import CodeAction, CodeObservation - >>> - >>> env = CodeExecutionEnvironment() - >>> app = create_fastapi_app(env, CodeAction, CodeObservation) - >>> - >>> # Run with: uvicorn module:app --host 0.0.0.0 --port 8000 - """ + """Create a FastAPI application with comprehensive documentation.""" try: from fastapi import FastAPI except ImportError: @@ -332,7 +436,62 @@ def create_fastapi_app( "FastAPI is required. Install with: pip install fastapi uvicorn" ) - app = FastAPI(title="Environment HTTP Server") + app = FastAPI( + title="OpenEnv Environment HTTP API", + version="1.0.0", + description=""" +# OpenEnv Environment HTTP API + +HTTP API for interacting with OpenEnv environments through a standardized interface. + +## Features + +* **Environment Reset**: Initialize or restart episodes +* **Action Execution**: Send actions and receive observations +* **State Inspection**: Query current environment state +* **Schema Access**: Retrieve JSON schemas for actions and observations + +## Workflow + +1. Call `/reset` to start a new episode and get initial observation +2. Call `/step` repeatedly with actions to interact with environment +3. Episode ends when observation returns `done: true` +4. Call `/state` anytime to inspect current environment state + +## Documentation + +* **Swagger UI**: Available at `/docs` +* **ReDoc**: Available at `/redoc` +* **OpenAPI Schema**: Available at `/openapi.json` + """, + openapi_tags=[ + { + "name": "Environment Control", + "description": "Core operations for environment interaction (reset, step)", + }, + { + "name": "State Management", + "description": "Operations for inspecting environment state", + }, + { + "name": "Schema", + "description": "JSON Schema endpoints for actions, observations, and state", + }, + {"name": "Health", "description": "Service health and status checks"}, + ], + docs_url="/docs", + redoc_url="/redoc", + openapi_url="/openapi.json", + contact={ + "name": "OpenEnv Team", + "url": "https://github.com/meta-pytorch/OpenEnv", + }, + license_info={ + "name": "BSD-3-Clause", + "url": "https://github.com/meta-pytorch/OpenEnv/blob/main/LICENSE", + }, + ) + server = HTTPEnvServer(env, action_cls, observation_cls) server.register_routes(app) return app diff --git a/src/core/env_server/web_interface.py b/src/core/env_server/web_interface.py index c9f899a5..d1ce374f 100644 --- a/src/core/env_server/web_interface.py +++ b/src/core/env_server/web_interface.py @@ -1312,184 +1312,112 @@ def _generate_instructions_section(metadata: Optional[EnvironmentMetadata]) -> s def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]: """Extract enhanced field metadata from Action class for form generation.""" + # Use Pydantic's JSON schema generation for robust metadata extraction + try: + schema = action_cls.model_json_schema() + except AttributeError: + # Fallback for non-Pydantic v2 models or if something goes wrong + return [] + + properties = schema.get("properties", {}) + required_fields = schema.get("required", []) action_fields = [] - if not hasattr(action_cls, "model_fields"): - return action_fields - for field_name, field_info in action_cls.model_fields.items(): + for field_name, field_info in properties.items(): if field_name == "metadata": continue - field_type = field_info.annotation - field_metadata = _extract_field_metadata(field_name, field_info) - - # Determine input type based on field type - input_type = _determine_input_type(field_type) + # JSON schema "type" can be a string or list/undefined + # Determine our internal input type + input_type = _determine_input_type_from_schema(field_info, field_name) - # Check if field is required - is_required = field_info.is_required() + is_required = field_name in required_fields action_fields.append( { "name": field_name, "type": input_type, "required": is_required, - "description": field_metadata.get("description", ""), - "default_value": field_metadata.get("default_value"), - "choices": field_metadata.get("choices", []), - "min_value": field_metadata.get("min_value"), - "max_value": field_metadata.get("max_value"), - "placeholder": field_metadata.get("placeholder", ""), - "help_text": field_metadata.get("help_text", ""), + "description": field_info.get("description", ""), + "default_value": field_info.get("default"), + "choices": field_info.get("enum"), + "min_value": field_info.get("minimum"), + "max_value": field_info.get("maximum"), + "min_length": field_info.get("minLength"), + "max_length": field_info.get("maxLength"), + "pattern": field_info.get("pattern"), + "placeholder": _generate_placeholder(field_name, field_info), + "help_text": _generate_help_text(field_name, field_info), } ) return action_fields - for field_name, field_info in action_cls.__dataclass_fields__.items(): - if field_name == "metadata": - continue - - field_type = field_info.type - field_metadata = _extract_field_metadata(field_name, field_info) - - # Determine input type based on field type - input_type = _determine_input_type(field_type) - # Check if field is required - is_required = field_info.default is field_info.default_factory - - action_fields.append( - { - "name": field_name, - "type": input_type, - "required": is_required, - "description": field_metadata.get("description", ""), - "default_value": field_metadata.get("default_value"), - "choices": field_metadata.get("choices", []), - "min_value": field_metadata.get("min_value"), - "max_value": field_metadata.get("max_value"), - "placeholder": field_metadata.get("placeholder", ""), - "help_text": field_metadata.get("help_text", ""), - } - ) - - return action_fields +def _determine_input_type_from_schema( + field_info: Dict[str, Any], field_name: str +) -> str: + """Determine the appropriate HTML input type from JSON schema info.""" + schema_type = field_info.get("type") + # Check for specific tensor field convention + if "tokens" in field_name.lower(): + return "tensor" -def _extract_field_metadata(field_name: str, field_info) -> Dict[str, Any]: - """Extract metadata from Pydantic field including description and type hints.""" - from typing import get_origin, get_args, Literal, Union + if "enum" in field_info: + return "select" - metadata = {} + if schema_type == "boolean": + return "checkbox" - # Extract description from Pydantic field description - if hasattr(field_info, "description") and field_info.description: - metadata["description"] = field_info.description + if schema_type == "integer" or schema_type == "number": + return "number" - # Extract default value - if hasattr(field_info, "default") and field_info.default is not None: - metadata["default_value"] = field_info.default + if schema_type == "string": + # Check if it should be a textarea + if ( + field_info.get("maxLength", 0) > 100 + or "message" in field_name.lower() + or "code" in field_name.lower() + ): + return "textarea" + return "text" - # Extract type information - field_type = field_info.annotation - origin = get_origin(field_type) + # Default fallback + return "text" - # Handle Literal types for dropdown choices - if origin is Literal: - args = get_args(field_type) - metadata["choices"] = list(args) - # Handle Optional types - if origin is Union: - args = get_args(field_type) - if len(args) == 2 and type(None) in args: - # This is Optional[SomeType] - non_none_type = args[0] if args[1] is type(None) else args[1] - metadata["optional"] = True - # Recursively check non-None type for choices - if get_origin(non_none_type) is Literal: - metadata["choices"] = list(get_args(non_none_type)) - else: - # Regular Union type - metadata["choices"] = [str(arg) for arg in args if arg is not type(None)] - - # Handle numeric constraints from Pydantic field - if hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra: - # Extract constraints from json_schema_extra if available - schema_extra = field_info.json_schema_extra - if "ge" in schema_extra: - metadata["min_value"] = schema_extra["ge"] - if "le" in schema_extra: - metadata["max_value"] = schema_extra["le"] - - # Handle numeric constraints based on type - if field_type in (int, float): - # Check for common constraint patterns in field name - if "count" in field_name.lower() or "num" in field_name.lower(): - metadata.setdefault("min_value", 0) - if "id" in field_name.lower(): - metadata.setdefault("min_value", 0) - - # Generate placeholder text +def _generate_placeholder(field_name: str, field_info: Dict[str, Any]) -> str: + """Generate placeholder text.""" if "message" in field_name.lower(): - metadata["placeholder"] = f"Enter {field_name.replace('_', ' ')}..." + return f"Enter {field_name.replace('_', ' ')}..." elif "code" in field_name.lower(): - metadata["placeholder"] = "Enter Python code here..." + return "Enter Python code here..." elif "tokens" in field_name.lower(): - metadata["placeholder"] = "Enter comma-separated token IDs (e.g., 1,2,3,4,5)" + return "Enter comma-separated token IDs (e.g., 1,2,3,4,5)" else: - metadata["placeholder"] = f"Enter {field_name.replace('_', ' ')}..." + return f"Enter {field_name.replace('_', ' ')}..." + + +def _generate_help_text(field_name: str, field_info: Dict[str, Any]) -> str: + """Generate help text.""" + description = field_info.get("description", "") + if description: + return description - # Generate help text based on field name and type if "action_id" in field_name.lower(): - metadata["help_text"] = "The action ID to execute in environment" + return "The action ID to execute in environment" elif "game_name" in field_name.lower(): - metadata["help_text"] = "Name of game or environment" + return "Name of game or environment" elif "tokens" in field_name.lower(): - metadata["help_text"] = "Token IDs as a comma-separated list of integers" + return "Token IDs as a comma-separated list of integers" elif "code" in field_name.lower(): - metadata["help_text"] = "Python code to execute in environment" + return "Python code to execute in environment" elif "message" in field_name.lower(): - metadata["help_text"] = "Text message to send" - - return metadata - - -def _determine_input_type(field_type) -> str: - """Determine the appropriate HTML input type for a field type.""" - from typing import get_origin, get_args, Literal, Union - - # Handle direct types - if field_type is str: - return "text" - elif field_type is int: - return "number" - elif field_type is float: - return "number" - elif field_type is bool: - return "checkbox" + return "Text message to send" - # Handle complex types - origin = get_origin(field_type) - - if origin is Literal: - return "select" - elif origin is Union: - args = get_args(field_type) - if len(args) == 2 and type(None) in args: - # Optional type - use the non-None type - non_none_type = args[0] if args[1] is type(None) else args[1] - return _determine_input_type(non_none_type) - elif all(isinstance(arg, str) for arg in args if arg is not type(None)): - return "select" - else: - return "text" - elif hasattr(field_type, "__name__") and "Tensor" in field_type.__name__: - return "tensor" - else: - return "text" + return "" def _markdown_to_html(markdown: str) -> str: @@ -1615,6 +1543,9 @@ def _generate_single_field(field: Dict[str, Any]) -> str: min_value = field.get("min_value") max_value = field.get("max_value") default_value = field.get("default_value") + min_length = field.get("min_length") + max_length = field.get("max_length") + pattern = field.get("pattern") # Build label with required indicator label_text = field_name.replace("_", " ").title() @@ -1631,16 +1562,23 @@ def _generate_single_field(field: Dict[str, Any]) -> str: input_attrs.append(f'min="{min_value}"') if max_value is not None: input_attrs.append(f'max="{max_value}"') + if min_length is not None: + input_attrs.append(f'minlength="{min_length}"') + if max_length is not None: + input_attrs.append(f'maxlength="{max_length}"') + if pattern is not None: + input_attrs.append(f'pattern="{pattern}"') if default_value is not None: input_attrs.append(f'value="{default_value}"') attrs_str = " ".join(input_attrs) if field_type == "checkbox": + checked = "checked" if default_value is True else "" return f'''
    {f'{help_text}' if help_text else ""} @@ -1677,13 +1615,11 @@ def _generate_single_field(field: Dict[str, Any]) -> str:
    ''' - elif field_type == "text" and ( - "message" in field_name.lower() or "code" in field_name.lower() - ): + elif field_type == "textarea": return f'''
    - + {f'{help_text}' if help_text else ""}
    ''' From a9038dc11686057d303b1ddf15ee5ad197844d44 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Wed, 19 Nov 2025 15:22:41 +0530 Subject: [PATCH 09/11] feat: env metadata --- src/core/env_server/http_server.py | 39 +++++++++++++++++++++++++++--- src/core/env_server/interfaces.py | 18 +++++++++++++- 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/src/core/env_server/http_server.py b/src/core/env_server/http_server.py index 6f3046cb..0cd16417 100644 --- a/src/core/env_server/http_server.py +++ b/src/core/env_server/http_server.py @@ -31,6 +31,7 @@ State, StepRequest, StepResponse, + EnvironmentMetadata, ) @@ -135,7 +136,11 @@ async def reset( if k in sig.parameters or has_kwargs: valid_kwargs[k] = v - observation = self.env.reset(**valid_kwargs) + # Run synchronous reset in thread pool to avoid blocking event loop + loop = asyncio.get_event_loop() + observation = await loop.run_in_executor( + self._executor, lambda: self.env.reset(**valid_kwargs) + ) return ResetResponse(**self._serialize_observation(observation)) @app.post( @@ -217,8 +222,11 @@ async def step(request: StepRequest) -> StepResponse: if k in sig.parameters or has_kwargs: valid_kwargs[k] = v - # Execute step - observation = self.env.step(action, **valid_kwargs) + # Run synchronous step in thread pool to avoid blocking event loop + loop = asyncio.get_event_loop() + observation = await loop.run_in_executor( + self._executor, lambda: self.env.step(action, **valid_kwargs) + ) # Return serialized observation return StepResponse(**self._serialize_observation(observation)) @@ -239,6 +247,27 @@ async def get_state() -> State: """State endpoint - returns current environment state.""" return self.env.state + @app.get( + "/metadata", + response_model=EnvironmentMetadata, + tags=["Environment Info"], + summary="Get environment metadata", + description=""" +Get metadata about this environment. + +Returns information about the environment including name, description, +version, author, and documentation links. + """, + ) + async def get_metadata() -> EnvironmentMetadata: + """ + Get metadata about this environment. + + Returns information about the environment including name, description, + version, author, and documentation links. + """ + return self.env.get_metadata() + @app.get( "/health", tags=["Health"], @@ -473,6 +502,10 @@ def create_fastapi_app( "name": "State Management", "description": "Operations for inspecting environment state", }, + { + "name": "Environment Info", + "description": "Information about the environment", + }, { "name": "Schema", "description": "JSON Schema endpoints for actions, observations, and state", diff --git a/src/core/env_server/interfaces.py b/src/core/env_server/interfaces.py index afcbdde9..b438cd66 100644 --- a/src/core/env_server/interfaces.py +++ b/src/core/env_server/interfaces.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from typing import Any, Optional, Protocol, TypedDict -from .types import Action, Observation, State +from .types import Action, Observation, State, EnvironmentMetadata class Message(TypedDict): @@ -121,6 +121,22 @@ def state(self) -> State: """Get the current environment state.""" pass + def get_metadata(self) -> EnvironmentMetadata: + """ + Get metadata about this environment. + + Override this method to provide custom metadata for the environment. + Default implementation returns basic metadata derived from class name. + + Returns: + EnvironmentMetadata with environment information + """ + return EnvironmentMetadata( + name=self.__class__.__name__, + description=f"{self.__class__.__name__} environment", + version="1.0.0", + ) + def _apply_transform(self, observation: Observation) -> Observation: """Apply transform if one is provided.""" if self.transform is not None: From bbf9252b2c5dab9c0f1a63bf777a2f91597f388b Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Tue, 25 Nov 2025 08:54:52 +0000 Subject: [PATCH 10/11] feat: serialization utilities and route config --- src/core/env_server/__init__.py | 15 +- src/core/env_server/http_server.py | 381 +++++++++++---------------- src/core/env_server/route_config.py | 60 +++++ src/core/env_server/serialization.py | 139 ++++++++++ src/core/env_server/types.py | 19 ++ src/core/env_server/web_interface.py | 75 ++---- 6 files changed, 403 insertions(+), 286 deletions(-) create mode 100644 src/core/env_server/route_config.py create mode 100644 src/core/env_server/serialization.py diff --git a/src/core/env_server/__init__.py b/src/core/env_server/__init__.py index 79e66535..a5401cca 100644 --- a/src/core/env_server/__init__.py +++ b/src/core/env_server/__init__.py @@ -9,7 +9,13 @@ from .base_transforms import CompositeTransform, NullTransform from .http_server import HTTPEnvServer, create_app, create_fastapi_app from .interfaces import Environment, Message, ModelTokenizer, Transform -from .types import Action, Observation, State +from .route_config import GetEndpointConfig +from .serialization import ( + deserialize_action, + deserialize_action_with_preprocessing, + serialize_observation, +) +from .types import Action, Observation, State, SchemaResponse from .web_interface import create_web_interface_app, WebInterfaceManager __all__ = [ @@ -22,6 +28,7 @@ "Action", "Observation", "State", + "SchemaResponse", # Base transforms "CompositeTransform", "NullTransform", @@ -32,4 +39,10 @@ # Web Interface "create_web_interface_app", "WebInterfaceManager", + # Serialization utilities + "deserialize_action", + "deserialize_action_with_preprocessing", + "serialize_observation", + # Route configuration + "GetEndpointConfig", ] diff --git a/src/core/env_server/http_server.py b/src/core/env_server/http_server.py index 0cd16417..e7267afe 100644 --- a/src/core/env_server/http_server.py +++ b/src/core/env_server/http_server.py @@ -23,6 +23,11 @@ from pydantic import ValidationError from .interfaces import Environment +from .route_config import ( + GetEndpointConfig, + register_get_endpoints, +) +from .serialization import deserialize_action, serialize_observation from .types import ( Action, Observation, @@ -32,6 +37,7 @@ StepRequest, StepResponse, EnvironmentMetadata, + SchemaResponse, ) @@ -80,6 +86,29 @@ def __init__( # This is needed for environments using sync libraries (e.g., Playwright sync API) self._executor = ThreadPoolExecutor(max_workers=1) + async def _run_sync_in_thread_pool(self, func, *args, **kwargs): + """Run a synchronous function in the thread pool executor.""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self._executor, lambda: func(*args, **kwargs)) + + def _get_valid_kwargs(self, sig, kwargs, skip_params=None): + """Filter kwargs to only include parameters accepted by the function signature.""" + if skip_params is None: + skip_params = set() + + valid_kwargs = {} + + has_kwargs = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() + ) + + for k, v in kwargs.items(): + if k in sig.parameters or has_kwargs: + if k not in skip_params: + valid_kwargs[k] = v + + return valid_kwargs + def register_routes(self, app: Any) -> None: """ Register HTTP routes on a FastAPI application. @@ -91,6 +120,56 @@ def register_routes(self, app: Any) -> None: if not isinstance(app, FastAPI): raise TypeError("app must be a FastAPI instance") + # Helper function to handle reset endpoint + async def reset_handler( + request: ResetRequest = Body(default_factory=ResetRequest), + ) -> ResetResponse: + """Reset endpoint - returns initial observation.""" + # Handle optional parameters + # Start with all fields from the request, including extra ones + kwargs = request.model_dump(exclude_unset=True) + + # Pass arguments only if environment accepts them + sig = inspect.signature(self.env.reset) + valid_kwargs = self._get_valid_kwargs(sig, kwargs) + + # Run synchronous reset in thread pool to avoid blocking event loop + observation = await self._run_sync_in_thread_pool( + self.env.reset, **valid_kwargs + ) + return ResetResponse(**serialize_observation(observation)) + + # Helper function to handle step endpoint + async def step_handler(request: StepRequest) -> StepResponse: + """Step endpoint - executes action and returns observation.""" + action_data = request.action + + # Deserialize action with Pydantic validation + try: + action = deserialize_action(action_data, self.action_cls) + except ValidationError as e: + # Return HTTP 422 with detailed validation errors + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors() + ) + + # Handle optional parameters + # Start with all fields from the request, including extra ones, but exclude 'action' + kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) + + # Pass arguments only if environment accepts them + sig = inspect.signature(self.env.step) + valid_kwargs = self._get_valid_kwargs(sig, kwargs, skip_params={"action"}) + + # Run synchronous step in thread pool to avoid blocking event loop + observation = await self._run_sync_in_thread_pool( + self.env.step, action, **valid_kwargs + ) + + # Return serialized observation + return StepResponse(**serialize_observation(observation)) + + # Register routes using the helpers @app.post( "/reset", response_model=ResetResponse, @@ -119,29 +198,7 @@ def register_routes(self, app: Any) -> None: async def reset( request: ResetRequest = Body(default_factory=ResetRequest), ) -> ResetResponse: - """Reset endpoint - returns initial observation.""" - # Handle optional parameters - # Start with all fields from the request, including extra ones - kwargs = request.model_dump(exclude_unset=True) - - # Pass arguments only if environment accepts them - sig = inspect.signature(self.env.reset) - valid_kwargs = {} - - has_kwargs = any( - p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() - ) - - for k, v in kwargs.items(): - if k in sig.parameters or has_kwargs: - valid_kwargs[k] = v - - # Run synchronous reset in thread pool to avoid blocking event loop - loop = asyncio.get_event_loop() - observation = await loop.run_in_executor( - self._executor, lambda: self.env.reset(**valid_kwargs) - ) - return ResetResponse(**self._serialize_observation(observation)) + return await reset_handler(request) @app.post( "/step", @@ -152,7 +209,7 @@ async def reset( Execute an action in the environment and receive the resulting observation. The action must conform to the environment's action schema, which can be -retrieved from the `/schema/action` endpoint. If the action is invalid, +retrieved from the `/schema` endpoint. If the action is invalid, the endpoint will return HTTP 422 with detailed validation errors. The response includes: @@ -194,223 +251,95 @@ async def reset( }, ) async def step(request: StepRequest) -> StepResponse: - """Step endpoint - executes action and returns observation.""" - action_data = request.action - - # Deserialize action with Pydantic validation - try: - action = self._deserialize_action(action_data) - except ValidationError as e: - # Return HTTP 422 with detailed validation errors - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=e.errors() - ) - - # Handle optional parameters - # Start with all fields from the request, including extra ones, but exclude 'action' - kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) - - # Pass arguments only if environment accepts them - sig = inspect.signature(self.env.step) - valid_kwargs = {} - - has_kwargs = any( - p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() - ) - - for k, v in kwargs.items(): - if k in sig.parameters or has_kwargs: - valid_kwargs[k] = v - - # Run synchronous step in thread pool to avoid blocking event loop - loop = asyncio.get_event_loop() - observation = await loop.run_in_executor( - self._executor, lambda: self.env.step(action, **valid_kwargs) - ) - - # Return serialized observation - return StepResponse(**self._serialize_observation(observation)) - - @app.get( - "/state", - response_model=State, - tags=["State Management"], - summary="Get current environment state", - description=""" + return await step_handler(request) + + # Configure and register GET endpoints declaratively + get_endpoints = [ + GetEndpointConfig( + path="/state", + handler=lambda: self.env.state, + response_model=State, + tag="State Management", + summary="Get current environment state", + description=""" Retrieve the current internal state of the environment. This endpoint allows inspection of the environment state without modifying it. The structure of the state object is defined by the environment's State model. - """, - ) - async def get_state() -> State: - """State endpoint - returns current environment state.""" - return self.env.state - - @app.get( - "/metadata", - response_model=EnvironmentMetadata, - tags=["Environment Info"], - summary="Get environment metadata", - description=""" + """, + ), + GetEndpointConfig( + path="/metadata", + handler=self.env.get_metadata, + response_model=EnvironmentMetadata, + tag="Environment Info", + summary="Get environment metadata", + description=""" Get metadata about this environment. Returns information about the environment including name, description, version, author, and documentation links. - """, - ) - async def get_metadata() -> EnvironmentMetadata: - """ - Get metadata about this environment. - - Returns information about the environment including name, description, - version, author, and documentation links. - """ - return self.env.get_metadata() - + """, + ), + GetEndpointConfig( + path="/health", + handler=lambda: {"status": "healthy"}, + response_model=Dict[str, str], + tag="Health", + summary="Health check", + description="Check if the environment server is running and healthy.", + ), + ] + register_get_endpoints(app, get_endpoints) + + # Register combined schema endpoint @app.get( - "/health", - tags=["Health"], - summary="Health check", - description="Check if the environment server is running and healthy.", - ) - async def health() -> Dict[str, str]: - """Health check endpoint.""" - return {"status": "healthy"} - - @app.get( - "/schema/action", + "/schema", + response_model=SchemaResponse, tags=["Schema"], - summary="Get action JSON schema", + summary="Get all JSON schemas", description=""" -Get JSON schema for actions accepted by this environment. - -Returns the complete JSON schema definition for the Action model, -including all field types, constraints, and validation rules. -This schema can be used to validate actions before sending them -to the environment, or to generate forms in web interfaces. - """, - ) - async def get_action_schema() -> Dict[str, Any]: - """ - Get JSON schema for actions accepted by this environment. +Get JSON schemas for actions, observations, and state in a single response. - Returns the complete JSON schema definition for the Action model, - including all field types, constraints, and validation rules. - This schema can be used to validate actions before sending them - to the environment, or to generate forms in web interfaces. +Returns a combined schema object containing: +- **action**: JSON schema for actions accepted by this environment +- **observation**: JSON schema for observations returned by this environment +- **state**: JSON schema for environment state objects - Returns: - Dict containing JSON Schema - """ - return self.action_cls.model_json_schema() - - @app.get( - "/schema/observation", - tags=["Schema"], - summary="Get observation JSON schema", - description=""" -Get JSON schema for observations returned by this environment. - -Returns the complete JSON schema definition for the Observation model, -including all field types and nested structures. This schema describes -what observations the environment will return after actions are executed. +This is more efficient than calling individual schema endpoints and provides +all schema information needed to interact with the environment. """, + responses={ + 200: { + "description": "Combined schemas retrieved successfully", + "content": { + "application/json": { + "example": { + "action": { + "type": "object", + "properties": {"message": {"type": "string"}}, + }, + "observation": { + "type": "object", + "properties": {"response": {"type": "string"}}, + }, + "state": { + "type": "object", + "properties": {"step_count": {"type": "integer"}}, + }, + } + } + }, + } + }, ) - async def get_observation_schema() -> Dict[str, Any]: - """ - Get JSON schema for observations returned by this environment. - - Returns the complete JSON schema definition for the Observation model, - including all field types and nested structures. This schema describes - what observations the environment will return after actions are executed. - - Returns: - Dict containing JSON Schema - """ - return self.observation_cls.model_json_schema() - - @app.get( - "/schema/state", - tags=["Schema"], - summary="Get state JSON schema", - description=""" -Get JSON schema for environment state objects. - -Returns the complete JSON schema definition for the State model. -This schema describes the internal state representation of the -environment, which can be queried via the /state endpoint. - """, - ) - async def get_state_schema() -> Dict[str, Any]: - """ - Get JSON schema for environment state objects. - - Returns the complete JSON schema definition for the State model. - This schema describes the internal state representation of the - environment, which can be queried via the /state endpoint. - - Returns: - Dict containing JSON Schema - """ - return State.model_json_schema() - - def _deserialize_action(self, action_data: Dict[str, Any]) -> Action: - """ - Convert JSON dict to Action instance using Pydantic validation. - - Args: - action_data: Dictionary containing action data - - Returns: - Action instance - - Raises: - ValidationError: If action_data is invalid for the action class - - Note: - This uses Pydantic's model_validate() for automatic validation. - """ - # Pydantic handles validation automatically - action = self.action_cls.model_validate(action_data) - return action - - def _serialize_observation(self, observation: Observation) -> Dict[str, Any]: - """ - Convert Observation instance to JSON-compatible dict using Pydantic. - - Args: - observation: Observation instance - - Returns: - Dictionary compatible with HTTPEnvClient._parse_result() - - The format matches what HTTPEnvClient expects: - { - "observation": {...}, # Observation fields - "reward": float | None, - "done": bool, - } - """ - # Use Pydantic's model_dump() for serialization - obs_dict = observation.model_dump( - exclude={ - "reward", - "done", - "metadata", - } # Exclude these from observation dict - ) - - # Extract reward and done directly from the observation - reward = observation.reward - done = observation.done - - # Return in HTTPEnvClient expected format - return { - "observation": obs_dict, - "reward": reward, - "done": done, - } + async def get_schemas() -> SchemaResponse: + """Return all schemas in one response.""" + return SchemaResponse( + action=self.action_cls.model_json_schema(), + observation=self.observation_cls.model_json_schema(), + state=State.model_json_schema(), + ) def create_app( diff --git a/src/core/env_server/route_config.py b/src/core/env_server/route_config.py new file mode 100644 index 00000000..a429bbb3 --- /dev/null +++ b/src/core/env_server/route_config.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Route configuration utilities for declarative FastAPI route registration. + +This module provides utilities to reduce boilerplate in route registration +by using configuration objects instead of repeated function calls. +""" + +from dataclasses import dataclass +from typing import Callable, List, Type, TypeVar + +from fastapi import FastAPI +from pydantic import BaseModel + +# TypeVar for generic response types +T = TypeVar("T", bound=BaseModel) + + +@dataclass +class GetEndpointConfig: + """Configuration for a simple GET endpoint.""" + + path: str + handler: Callable[[], BaseModel | dict] + response_model: Type[BaseModel] | Type[dict] + tag: str + summary: str + description: str + + +def register_get_endpoints(app: FastAPI, configs: List[GetEndpointConfig]) -> None: + """ + Register multiple GET endpoints from configuration. + + Args: + app: FastAPI application instance + configs: List of GET endpoint configurations + """ + for config in configs: + # Capture handler in a closure to avoid non-serializable default parameter + def make_endpoint( + handler: Callable[[], BaseModel | dict], + ) -> Callable[[], BaseModel | dict]: + async def endpoint() -> BaseModel | dict: + return handler() + + return endpoint + + app.get( + config.path, + response_model=config.response_model, + tags=[config.tag], + summary=config.summary, + description=config.description, + )(make_endpoint(config.handler)) diff --git a/src/core/env_server/serialization.py b/src/core/env_server/serialization.py new file mode 100644 index 00000000..a97a0528 --- /dev/null +++ b/src/core/env_server/serialization.py @@ -0,0 +1,139 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Shared serialization and deserialization utilities for OpenEnv HTTP servers. + +This module provides common utilities for converting between JSON dictionaries +and Pydantic models (Action/Observation) to eliminate code duplication across +HTTP server and web interface implementations. +""" + +from typing import Any, Dict, Type + +from .types import Action, Observation + + +def deserialize_action( + action_data: Dict[str, Any], action_cls: Type[Action] +) -> Action: + """ + Convert JSON dict to Action instance using Pydantic validation. + + This is a basic deserialization that works for most environments. + For special cases (e.g., tensor fields, custom type conversions), + use deserialize_action_with_preprocessing(). + + Args: + action_data: Dictionary containing action data + action_cls: The Action subclass to instantiate + + Returns: + Action instance + + Raises: + ValidationError: If action_data is invalid for the action class + + Note: + This uses Pydantic's model_validate() for automatic validation. + """ + return action_cls.model_validate(action_data) + + +def deserialize_action_with_preprocessing( + action_data: Dict[str, Any], action_cls: Type[Action] +) -> Action: + """ + Convert JSON dict to Action instance with preprocessing for special types. + + This version handles common type conversions needed for web interfaces: + - Converting lists/strings to tensors for 'tokens' field + - Converting string action_id to int + - Other custom preprocessing as needed + + Args: + action_data: Dictionary containing action data + action_cls: The Action subclass to instantiate + + Returns: + Action instance + + Raises: + ValidationError: If action_data is invalid for the action class + """ + processed_data = {} + + for key, value in action_data.items(): + if key == "tokens" and isinstance(value, (list, str)): + # Convert list or string to tensor + if isinstance(value, str): + # If it's a string, try to parse it as a list of numbers + try: + import json + + value = json.loads(value) + except Exception: + # If parsing fails, treat as empty list + value = [] + if isinstance(value, list): + try: + import torch + + processed_data[key] = torch.tensor(value, dtype=torch.long) + except ImportError: + # If torch not available, keep as list + processed_data[key] = value + else: + processed_data[key] = value + elif key == "action_id" and isinstance(value, str): + # Convert action_id from string to int + try: + processed_data[key] = int(value) + except ValueError: + # If conversion fails, keep original value + processed_data[key] = value + else: + processed_data[key] = value + + return action_cls.model_validate(processed_data) + + +def serialize_observation(observation: Observation) -> Dict[str, Any]: + """ + Convert Observation instance to JSON-compatible dict using Pydantic. + + Args: + observation: Observation instance + + Returns: + Dictionary compatible with HTTPEnvClient._parse_result() + + The format matches what HTTPEnvClient expects: + { + "observation": {...}, # Observation fields + "reward": float | None, + "done": bool, + } + """ + # Use Pydantic's model_dump() for serialization + obs_dict = observation.model_dump( + exclude={ + "reward", + "done", + "metadata", + } # Exclude these from observation dict + ) + + # Extract reward and done directly from the observation + reward = observation.reward + done = observation.done + + # Return in HTTPEnvClient expected format + return { + "observation": obs_dict, + "reward": reward, + "done": done, + } diff --git a/src/core/env_server/types.py b/src/core/env_server/types.py index d96d7baf..8d63f7d7 100644 --- a/src/core/env_server/types.py +++ b/src/core/env_server/types.py @@ -182,3 +182,22 @@ class EnvironmentMetadata(BaseModel): documentation_url: Optional[str] = Field( default=None, description="URL to the environment's documentation" ) + + +class SchemaResponse(BaseModel): + """Response model for the combined schema endpoint.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + action: Dict[str, Any] = Field( + description="JSON schema for actions accepted by this environment" + ) + observation: Dict[str, Any] = Field( + description="JSON schema for observations returned by this environment" + ) + state: Dict[str, Any] = Field( + description="JSON schema for environment state objects" + ) diff --git a/src/core/env_server/web_interface.py b/src/core/env_server/web_interface.py index d1ce374f..b370cfa5 100644 --- a/src/core/env_server/web_interface.py +++ b/src/core/env_server/web_interface.py @@ -22,6 +22,7 @@ from pydantic import BaseModel, Field, ConfigDict from .interfaces import Environment +from .serialization import deserialize_action_with_preprocessing, serialize_observation from .types import Action, Observation, State, EnvironmentMetadata @@ -192,40 +193,40 @@ async def reset_environment(self) -> Dict[str, Any]: observation: Observation = self.env.reset() state: State = self.env.state + # Serialize observation once using shared utility + serialized = serialize_observation(observation) + # Update episode state self.episode_state.episode_id = state.episode_id self.episode_state.step_count = 0 - self.episode_state.current_observation = observation.model_dump( - exclude={"reward", "done", "metadata"} - ) + self.episode_state.current_observation = serialized["observation"] self.episode_state.action_logs = [] self.episode_state.is_reset = True # Send state update await self._send_state_update() - return { - "observation": observation.model_dump( - exclude={"reward", "done", "metadata"} - ), - "reward": observation.reward, - "done": observation.done, - } + return serialized async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]: """Execute a step in the environment and update state.""" - # Deserialize action - action: Action = self._deserialize_action(action_data) + # Deserialize action with preprocessing for web interface special cases + action: Action = deserialize_action_with_preprocessing( + action_data, self.action_cls + ) # Execute step observation: Observation = self.env.step(action) state: State = self.env.state + # Serialize observation once using shared utility + serialized = serialize_observation(observation) + # Create action log action_log = ActionLog( timestamp=datetime.now().isoformat(), action=action.model_dump(exclude={"metadata"}), - observation=observation.model_dump(exclude={"reward", "done", "metadata"}), + observation=serialized["observation"], reward=observation.reward, done=observation.done, step_count=state.step_count, @@ -234,64 +235,20 @@ async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]: # Update episode state self.episode_state.episode_id = state.episode_id self.episode_state.step_count = state.step_count - self.episode_state.current_observation = observation.model_dump( - exclude={"reward", "done", "metadata"} - ) + self.episode_state.current_observation = serialized["observation"] self.episode_state.action_logs.append(action_log) self.episode_state.is_reset = False # Send state update await self._send_state_update() - return { - "observation": observation.model_dump( - exclude={"reward", "done", "metadata"} - ), - "reward": observation.reward, - "done": observation.done, - } + return serialized def get_state(self) -> Dict[str, Any]: """Get current environment state.""" state: State = self.env.state return state.model_dump() - def _deserialize_action(self, action_data: Dict[str, Any]) -> Action: - """Convert JSON dict to Action instance using Pydantic validation.""" - # Handle tensor fields that come from JSON as lists - processed_data = {} - for key, value in action_data.items(): - if key == "tokens" and isinstance(value, (list, str)): - # Convert list or string to tensor - if isinstance(value, str): - # If it's a string, try to parse it as a list of numbers - try: - import json - - value = json.loads(value) - except Exception: - # If parsing fails, treat as empty list - value = [] - if isinstance(value, list): - import torch - - processed_data[key] = torch.tensor(value, dtype=torch.long) - else: - processed_data[key] = value - elif key == "action_id" and isinstance(value, str): - # Convert action_id from string to int - try: - processed_data[key] = int(value) - except ValueError: - # If conversion fails, keep original value - processed_data[key] = value - else: - processed_data[key] = value - - # Use Pydantic's model_validate for automatic validation - action = self.action_cls.model_validate(processed_data) - return action - def create_web_interface_app( env: Environment, From c4f20d738bc78b1657162df1cfceb5351a3f2765 Mon Sep 17 00:00:00 2001 From: swappy <59965507+rycerzes@users.noreply.github.com> Date: Tue, 25 Nov 2025 15:33:12 +0530 Subject: [PATCH 11/11] chore: types --- src/core/env_server/__init__.py | 3 ++- src/core/env_server/http_server.py | 16 +++++++--------- src/core/env_server/route_config.py | 7 ++----- src/core/env_server/types.py | 13 ++++++++++++- 4 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/core/env_server/__init__.py b/src/core/env_server/__init__.py index a5401cca..4e1c2d7a 100644 --- a/src/core/env_server/__init__.py +++ b/src/core/env_server/__init__.py @@ -15,7 +15,7 @@ deserialize_action_with_preprocessing, serialize_observation, ) -from .types import Action, Observation, State, SchemaResponse +from .types import Action, Observation, State, SchemaResponse, HealthResponse from .web_interface import create_web_interface_app, WebInterfaceManager __all__ = [ @@ -29,6 +29,7 @@ "Observation", "State", "SchemaResponse", + "HealthResponse", # Base transforms "CompositeTransform", "NullTransform", diff --git a/src/core/env_server/http_server.py b/src/core/env_server/http_server.py index e7267afe..7fa7c0f3 100644 --- a/src/core/env_server/http_server.py +++ b/src/core/env_server/http_server.py @@ -17,7 +17,7 @@ import inspect import os from concurrent.futures import ThreadPoolExecutor -from typing import Any, Dict, Optional, Type +from typing import Optional, Type from fastapi import Body, FastAPI, HTTPException, status from pydantic import ValidationError @@ -38,6 +38,7 @@ StepResponse, EnvironmentMetadata, SchemaResponse, + HealthResponse, ) @@ -109,7 +110,7 @@ def _get_valid_kwargs(self, sig, kwargs, skip_params=None): return valid_kwargs - def register_routes(self, app: Any) -> None: + def register_routes(self, app: FastAPI) -> None: """ Register HTTP routes on a FastAPI application. @@ -117,9 +118,6 @@ def register_routes(self, app: Any) -> None: app: FastAPI application instance """ - if not isinstance(app, FastAPI): - raise TypeError("app must be a FastAPI instance") - # Helper function to handle reset endpoint async def reset_handler( request: ResetRequest = Body(default_factory=ResetRequest), @@ -283,8 +281,8 @@ async def step(request: StepRequest) -> StepResponse: ), GetEndpointConfig( path="/health", - handler=lambda: {"status": "healthy"}, - response_model=Dict[str, str], + handler=lambda: HealthResponse(status="healthy"), + response_model=HealthResponse, tag="Health", summary="Health check", description="Check if the environment server is running and healthy.", @@ -347,7 +345,7 @@ def create_app( action_cls: Type[Action], observation_cls: Type[Observation], env_name: Optional[str] = None, -) -> Any: +) -> FastAPI: """ Create a FastAPI application with or without web interface. @@ -385,7 +383,7 @@ def create_fastapi_app( env: Environment, action_cls: Type[Action], observation_cls: Type[Observation], -) -> Any: +) -> FastAPI: """Create a FastAPI application with comprehensive documentation.""" try: from fastapi import FastAPI diff --git a/src/core/env_server/route_config.py b/src/core/env_server/route_config.py index a429bbb3..08807c68 100644 --- a/src/core/env_server/route_config.py +++ b/src/core/env_server/route_config.py @@ -12,14 +12,11 @@ """ from dataclasses import dataclass -from typing import Callable, List, Type, TypeVar +from typing import Callable, List, Type from fastapi import FastAPI from pydantic import BaseModel -# TypeVar for generic response types -T = TypeVar("T", bound=BaseModel) - @dataclass class GetEndpointConfig: @@ -27,7 +24,7 @@ class GetEndpointConfig: path: str handler: Callable[[], BaseModel | dict] - response_model: Type[BaseModel] | Type[dict] + response_model: Type[BaseModel] | type[dict] tag: str summary: str description: str diff --git a/src/core/env_server/types.py b/src/core/env_server/types.py index 8d63f7d7..c3ee689c 100644 --- a/src/core/env_server/types.py +++ b/src/core/env_server/types.py @@ -44,7 +44,7 @@ class Observation(BaseModel): ) done: bool = Field(default=False, description="Whether the episode has terminated") - reward: Union[bool, int, float, None] = Field( + reward: bool | int | float | None = Field( default=None, description="Reward signal from the last action" ) metadata: Dict[str, Any] = Field( @@ -201,3 +201,14 @@ class SchemaResponse(BaseModel): state: Dict[str, Any] = Field( description="JSON schema for environment state objects" ) + + +class HealthResponse(BaseModel): + """Response model for health check endpoint.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + status: str = Field(description="Health status of the environment server")