diff --git a/src/core/env_server/__init__.py b/src/core/env_server/__init__.py index 79e66535..4e1c2d7a 100644 --- a/src/core/env_server/__init__.py +++ b/src/core/env_server/__init__.py @@ -9,7 +9,13 @@ from .base_transforms import CompositeTransform, NullTransform from .http_server import HTTPEnvServer, create_app, create_fastapi_app from .interfaces import Environment, Message, ModelTokenizer, Transform -from .types import Action, Observation, State +from .route_config import GetEndpointConfig +from .serialization import ( + deserialize_action, + deserialize_action_with_preprocessing, + serialize_observation, +) +from .types import Action, Observation, State, SchemaResponse, HealthResponse from .web_interface import create_web_interface_app, WebInterfaceManager __all__ = [ @@ -22,6 +28,8 @@ "Action", "Observation", "State", + "SchemaResponse", + "HealthResponse", # Base transforms "CompositeTransform", "NullTransform", @@ -32,4 +40,10 @@ # Web Interface "create_web_interface_app", "WebInterfaceManager", + # Serialization utilities + "deserialize_action", + "deserialize_action_with_preprocessing", + "serialize_observation", + # Route configuration + "GetEndpointConfig", ] diff --git a/src/core/env_server/http_server.py b/src/core/env_server/http_server.py index 207235f6..7fa7c0f3 100644 --- a/src/core/env_server/http_server.py +++ b/src/core/env_server/http_server.py @@ -1,257 +1,457 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -HTTP server wrapper for Environment instances. - -This module provides utilities to wrap any Environment subclass and expose it -over HTTP endpoints that HTTPEnvClient can consume. 
-""" - -from __future__ import annotations - -import asyncio -import os -from concurrent.futures import ThreadPoolExecutor -from dataclasses import asdict -from typing import Any, Dict, Type - -from .interfaces import Environment -from .types import Action, Observation -from fastapi import Body, FastAPI - -class HTTPEnvServer: - """ - HTTP server wrapper for Environment instances. - - This class wraps an Environment and exposes its reset(), step(), and state - methods as HTTP endpoints compatible with HTTPEnvClient. - - The server expects: - - Action deserialization: Converts JSON dict to Action subclass - - Observation serialization: Converts Observation subclass to JSON dict - - Example: - >>> from core.env_server import HTTPEnvServer - >>> from envs.coding_env.server import CodeExecutionEnvironment - >>> - >>> env = CodeExecutionEnvironment() - >>> server = HTTPEnvServer(env) - >>> - >>> # Register routes with FastAPI - >>> from fastapi import FastAPI - >>> app = FastAPI() - >>> server.register_routes(app) - """ - - def __init__( - self, - env: Environment, - action_cls: Type[Action], - observation_cls: Type[Observation], - ): - """ - Initialize HTTP server wrapper. - - Args: - env: The Environment instance to wrap - action_cls: The Action subclass this environment expects - observation_cls: The Observation subclass this environment returns - """ - self.env = env - self.action_cls = action_cls - self.observation_cls = observation_cls - # Create thread pool for running sync code in async context - # This is needed for environments using sync libraries (e.g., Playwright sync API) - self._executor = ThreadPoolExecutor(max_workers=1) - - def register_routes(self, app: Any) -> None: - """ - Register HTTP routes on a FastAPI application. 
- - Args: - app: FastAPI application instance - """ - - if not isinstance(app, FastAPI): - raise TypeError("app must be a FastAPI instance") - - @app.post("/reset") - async def reset(request: Dict[str, Any] = Body(default={})) -> Dict[str, Any]: - """Reset endpoint - returns initial observation.""" - # TODO: Handle seed, episode_id from request if provided - # Run sync environment code in thread pool to avoid blocking asyncio loop - loop = asyncio.get_event_loop() - observation = await loop.run_in_executor(self._executor, self.env.reset) - return self._serialize_observation(observation) - - @app.post("/step") - async def step(request: Dict[str, Any]) -> Dict[str, Any]: - """Step endpoint - executes action and returns observation.""" - # Support both {"action": {...}} and direct action fields - action_data = request.get("action", request) - # TODO: Handle timeout_s, request_id, episode_id from request if provided - - # Deserialize action - action = self._deserialize_action(action_data) - - # Execute step in thread pool to avoid blocking asyncio loop - loop = asyncio.get_event_loop() - observation = await loop.run_in_executor( - self._executor, self.env.step, action - ) - - # Return serialized observation - return self._serialize_observation(observation) - - @app.get("/state") - async def get_state() -> Dict[str, Any]: - """State endpoint - returns current environment state.""" - state = self.env.state - return asdict(state) - - @app.get("/health") - async def health() -> Dict[str, str]: - """Health check endpoint.""" - return {"status": "healthy"} - - - def _deserialize_action(self, action_data: Dict[str, Any]) -> Action: - """ - Convert JSON dict to Action instance. - - Args: - action_data: Dictionary containing action data - - Returns: - Action instance - - Note: - This is a simple implementation. Subclasses may need to override - for more complex deserialization logic. 
- """ - # Remove metadata if present (it will be set via kw_only field) - metadata = action_data.pop("metadata", {}) - action = self.action_cls(**action_data) - action.metadata = metadata - return action - - def _serialize_observation(self, observation: Observation) -> Dict[str, Any]: - """ - Convert Observation instance to JSON-compatible dict. - - Args: - observation: Observation instance - - Returns: - Dictionary compatible with HTTPEnvClient._parse_result() - - The format matches what HTTPEnvClient expects: - { - "observation": {...}, # Observation fields - "reward": float | None, - "done": bool, - } - """ - obs_dict = asdict(observation) - - # Convert numpy arrays to lists for JSON serialization - def _convert_numpy(obj): - """Recursively convert numpy arrays to lists.""" - if hasattr(obj, '__array__'): # numpy array - return obj.tolist() - elif isinstance(obj, dict): - return {k: _convert_numpy(v) for k, v in obj.items()} - elif isinstance(obj, (list, tuple)): - return type(obj)(_convert_numpy(item) for item in obj) - return obj - - obs_dict = _convert_numpy(obs_dict) - - # Extract reward and done (these are part of StepResult on client side) - reward = obs_dict.pop("reward", None) - done = obs_dict.pop("done", False) - obs_dict.pop("metadata", None) # Remove metadata from observation - - # Return in HTTPEnvClient expected format - return { - "observation": obs_dict, - "reward": reward, - "done": done, - } - -def create_app( - env: Environment, - action_cls: Type[Action], - observation_cls: Type[Observation], - env_name: Optional[str] = None, -) -> Any: - """ - Create a FastAPI application with or without web interface. - - This function creates a FastAPI app with the web interface enabled by default, - including README integration for better user experience. 
- - Args: - env: The Environment instance to serve - action_cls: The Action subclass this environment expects - observation_cls: The Observation subclass this environment returns - env_name: Optional environment name for README loading - - Returns: - FastAPI application instance with or without web interface and README integration - """ - # Check if web interface should be enabled - # This can be controlled via environment variable or build argument - enable_web = ( - os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in ("true", "1", "yes") - ) - - if enable_web: - # Import web interface only when needed - from .web_interface import create_web_interface_app - return create_web_interface_app(env, action_cls, observation_cls, env_name) - else: - # Use standard FastAPI app without web interface - return create_fastapi_app(env, action_cls, observation_cls) - - -def create_fastapi_app( - env: Environment, - action_cls: Type[Action], - observation_cls: Type[Observation], -) -> Any: - """ - Create a FastAPI application with routes for the given environment. - - Args: - env: The Environment instance to serve - action_cls: The Action subclass this environment expects - observation_cls: The Observation subclass this environment returns - - Returns: - FastAPI application instance with routes registered - - Example: - >>> from envs.coding_env.server import CodeExecutionEnvironment - >>> from envs.coding_env.models import CodeAction, CodeObservation - >>> - >>> env = CodeExecutionEnvironment() - >>> app = create_fastapi_app(env, CodeAction, CodeObservation) - >>> - >>> # Run with: uvicorn module:app --host 0.0.0.0 --port 8000 - """ - try: - from fastapi import FastAPI - except ImportError: - raise ImportError( - "FastAPI is required. Install with: pip install fastapi uvicorn" - ) - - app = FastAPI(title="Environment HTTP Server") - server = HTTPEnvServer(env, action_cls, observation_cls) - server.register_routes(app) - return app +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +HTTP server wrapper for Environment instances. + +This module provides utilities to wrap any Environment subclass and expose it +over HTTP endpoints that HTTPEnvClient can consume. +""" + +from __future__ import annotations + +import asyncio +import inspect +import os +from concurrent.futures import ThreadPoolExecutor +from typing import Optional, Type + +from fastapi import Body, FastAPI, HTTPException, status +from pydantic import ValidationError + +from .interfaces import Environment +from .route_config import ( + GetEndpointConfig, + register_get_endpoints, +) +from .serialization import deserialize_action, serialize_observation +from .types import ( + Action, + Observation, + ResetRequest, + ResetResponse, + State, + StepRequest, + StepResponse, + EnvironmentMetadata, + SchemaResponse, + HealthResponse, +) + + +class HTTPEnvServer: + """ + HTTP server wrapper for Environment instances. + + This class wraps an Environment and exposes its reset(), step(), and state + methods as HTTP endpoints compatible with HTTPEnvClient. + + The server expects: + - Action deserialization: Converts JSON dict to Action subclass + - Observation serialization: Converts Observation subclass to JSON dict + + Example: + >>> from core.env_server import HTTPEnvServer + >>> from envs.coding_env.server import CodeExecutionEnvironment + >>> from envs.coding_env.models import CodeAction, CodeObservation + >>> env = CodeExecutionEnvironment() + >>> server = HTTPEnvServer(env, CodeAction, CodeObservation) + >>> + >>> # Register routes with FastAPI + >>> from fastapi import FastAPI + >>> app = FastAPI() + >>> server.register_routes(app) + """ + + def __init__( + self, + env: Environment, + action_cls: Type[Action], + observation_cls: Type[Observation], + ): + """ + Initialize HTTP server wrapper. 
+ + Args: + env: The Environment instance to wrap + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + """ + self.env = env + self.action_cls = action_cls + self.observation_cls = observation_cls + # Create thread pool for running sync code in async context + # This is needed for environments using sync libraries (e.g., Playwright sync API) + self._executor = ThreadPoolExecutor(max_workers=1) + + async def _run_sync_in_thread_pool(self, func, *args, **kwargs): + """Run a synchronous function in the thread pool executor.""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self._executor, lambda: func(*args, **kwargs)) + + def _get_valid_kwargs(self, sig, kwargs, skip_params=None): + """Filter kwargs to only include parameters accepted by the function signature.""" + if skip_params is None: + skip_params = set() + + valid_kwargs = {} + + has_kwargs = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() + ) + + for k, v in kwargs.items(): + if k in sig.parameters or has_kwargs: + if k not in skip_params: + valid_kwargs[k] = v + + return valid_kwargs + + def register_routes(self, app: FastAPI) -> None: + """ + Register HTTP routes on a FastAPI application. 
+ + Args: + app: FastAPI application instance + """ + + # Helper function to handle reset endpoint + async def reset_handler( + request: ResetRequest = Body(default_factory=ResetRequest), + ) -> ResetResponse: + """Reset endpoint - returns initial observation.""" + # Handle optional parameters + # Start with all fields from the request, including extra ones + kwargs = request.model_dump(exclude_unset=True) + + # Pass arguments only if environment accepts them + sig = inspect.signature(self.env.reset) + valid_kwargs = self._get_valid_kwargs(sig, kwargs) + + # Run synchronous reset in thread pool to avoid blocking event loop + observation = await self._run_sync_in_thread_pool( + self.env.reset, **valid_kwargs + ) + return ResetResponse(**serialize_observation(observation)) + + # Helper function to handle step endpoint + async def step_handler(request: StepRequest) -> StepResponse: + """Step endpoint - executes action and returns observation.""" + action_data = request.action + + # Deserialize action with Pydantic validation + try: + action = deserialize_action(action_data, self.action_cls) + except ValidationError as e: + # Return HTTP 422 with detailed validation errors + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors() + ) + + # Handle optional parameters + # Start with all fields from the request, including extra ones, but exclude 'action' + kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) + + # Pass arguments only if environment accepts them + sig = inspect.signature(self.env.step) + valid_kwargs = self._get_valid_kwargs(sig, kwargs, skip_params={"action"}) + + # Run synchronous step in thread pool to avoid blocking event loop + observation = await self._run_sync_in_thread_pool( + self.env.step, action, **valid_kwargs + ) + + # Return serialized observation + return StepResponse(**serialize_observation(observation)) + + # Register routes using the helpers + @app.post( + "/reset", + 
response_model=ResetResponse, + tags=["Environment Control"], + summary="Reset the environment", + description=""" +Reset the environment to its initial state and return the first observation. + +You can optionally provide a seed for reproducibility and an episode_id for tracking. + """, + responses={ + 200: { + "description": "Environment reset successfully", + "content": { + "application/json": { + "example": { + "observation": {"status": "ready", "data": {}}, + "reward": None, + "done": False, + } + } + }, + } + }, + ) + async def reset( + request: ResetRequest = Body(default_factory=ResetRequest), + ) -> ResetResponse: + return await reset_handler(request) + + @app.post( + "/step", + response_model=StepResponse, + tags=["Environment Control"], + summary="Execute an action in the environment", + description=""" +Execute an action in the environment and receive the resulting observation. + +The action must conform to the environment's action schema, which can be +retrieved from the `/schema` endpoint. If the action is invalid, +the endpoint will return HTTP 422 with detailed validation errors. 
+ +The response includes: +- **observation**: The environment's response to the action +- **reward**: Optional reward signal (float or None) +- **done**: Boolean indicating if the episode has terminated + """, + responses={ + 200: { + "description": "Action executed successfully", + "content": { + "application/json": { + "example": { + "observation": {"status": "success", "data": {}}, + "reward": 1.0, + "done": False, + } + } + }, + }, + 422: { + "description": "Validation error - invalid action format or values", + "content": { + "application/json": { + "example": { + "detail": [ + { + "type": "string_too_short", + "loc": ["body", "action", "message"], + "msg": "String should have at least 1 character", + "input": "", + } + ] + } + } + }, + }, + 500: {"description": "Internal server error during action execution"}, + }, + ) + async def step(request: StepRequest) -> StepResponse: + return await step_handler(request) + + # Configure and register GET endpoints declaratively + get_endpoints = [ + GetEndpointConfig( + path="/state", + handler=lambda: self.env.state, + response_model=State, + tag="State Management", + summary="Get current environment state", + description=""" +Retrieve the current internal state of the environment. + +This endpoint allows inspection of the environment state without modifying it. +The structure of the state object is defined by the environment's State model. + """, + ), + GetEndpointConfig( + path="/metadata", + handler=self.env.get_metadata, + response_model=EnvironmentMetadata, + tag="Environment Info", + summary="Get environment metadata", + description=""" +Get metadata about this environment. + +Returns information about the environment including name, description, +version, author, and documentation links. 
+ """, + ), + GetEndpointConfig( + path="/health", + handler=lambda: HealthResponse(status="healthy"), + response_model=HealthResponse, + tag="Health", + summary="Health check", + description="Check if the environment server is running and healthy.", + ), + ] + register_get_endpoints(app, get_endpoints) + + # Register combined schema endpoint + @app.get( + "/schema", + response_model=SchemaResponse, + tags=["Schema"], + summary="Get all JSON schemas", + description=""" +Get JSON schemas for actions, observations, and state in a single response. + +Returns a combined schema object containing: +- **action**: JSON schema for actions accepted by this environment +- **observation**: JSON schema for observations returned by this environment +- **state**: JSON schema for environment state objects + +This is more efficient than calling individual schema endpoints and provides +all schema information needed to interact with the environment. + """, + responses={ + 200: { + "description": "Combined schemas retrieved successfully", + "content": { + "application/json": { + "example": { + "action": { + "type": "object", + "properties": {"message": {"type": "string"}}, + }, + "observation": { + "type": "object", + "properties": {"response": {"type": "string"}}, + }, + "state": { + "type": "object", + "properties": {"step_count": {"type": "integer"}}, + }, + } + } + }, + } + }, + ) + async def get_schemas() -> SchemaResponse: + """Return all schemas in one response.""" + return SchemaResponse( + action=self.action_cls.model_json_schema(), + observation=self.observation_cls.model_json_schema(), + state=State.model_json_schema(), + ) + + +def create_app( + env: Environment, + action_cls: Type[Action], + observation_cls: Type[Observation], + env_name: Optional[str] = None, +) -> FastAPI: + """ + Create a FastAPI application with or without web interface. 
+ + This function creates a FastAPI app with the web interface disabled by default; + set ENABLE_WEB_INTERFACE to true, 1, or yes to enable it with README integration. + + Args: + env: The Environment instance to serve + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + env_name: Optional environment name for README loading + + Returns: + FastAPI application instance with or without web interface and README integration + """ + # Check if web interface should be enabled + # This can be controlled via environment variable or build argument + enable_web = os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in ( + "true", + "1", + "yes", + ) + + if enable_web: + # Import web interface only when needed + from .web_interface import create_web_interface_app + + return create_web_interface_app(env, action_cls, observation_cls, env_name) + else: + # Use standard FastAPI app without web interface + return create_fastapi_app(env, action_cls, observation_cls) + + +def create_fastapi_app( + env: Environment, + action_cls: Type[Action], + observation_cls: Type[Observation], +) -> FastAPI: + """Create a FastAPI application with comprehensive documentation.""" + try: + from fastapi import FastAPI + except ImportError: + raise ImportError( + "FastAPI is required. Install with: pip install fastapi uvicorn" + ) + + app = FastAPI( + title="OpenEnv Environment HTTP API", + version="1.0.0", + description=""" +# OpenEnv Environment HTTP API + +HTTP API for interacting with OpenEnv environments through a standardized interface. + +## Features + +* **Environment Reset**: Initialize or restart episodes +* **Action Execution**: Send actions and receive observations +* **State Inspection**: Query current environment state +* **Schema Access**: Retrieve JSON schemas for actions and observations + +## Workflow + +1. Call `/reset` to start a new episode and get initial observation +2. 
Call `/step` repeatedly with actions to interact with environment +3. Episode ends when observation returns `done: true` +4. Call `/state` anytime to inspect current environment state + +## Documentation + +* **Swagger UI**: Available at `/docs` +* **ReDoc**: Available at `/redoc` +* **OpenAPI Schema**: Available at `/openapi.json` + """, + openapi_tags=[ + { + "name": "Environment Control", + "description": "Core operations for environment interaction (reset, step)", + }, + { + "name": "State Management", + "description": "Operations for inspecting environment state", + }, + { + "name": "Environment Info", + "description": "Information about the environment", + }, + { + "name": "Schema", + "description": "JSON Schema endpoints for actions, observations, and state", + }, + {"name": "Health", "description": "Service health and status checks"}, + ], + docs_url="/docs", + redoc_url="/redoc", + openapi_url="/openapi.json", + contact={ + "name": "OpenEnv Team", + "url": "https://github.com/meta-pytorch/OpenEnv", + }, + license_info={ + "name": "BSD-3-Clause", + "url": "https://github.com/meta-pytorch/OpenEnv/blob/main/LICENSE", + }, + ) + + server = HTTPEnvServer(env, action_cls, observation_cls) + server.register_routes(app) + return app diff --git a/src/core/env_server/interfaces.py b/src/core/env_server/interfaces.py index caa2d76d..b438cd66 100644 --- a/src/core/env_server/interfaces.py +++ b/src/core/env_server/interfaces.py @@ -1,118 +1,144 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from abc import ABC, abstractmethod -from typing import Any, Protocol, TypedDict - -from .types import Action, Observation, State - - -class Message(TypedDict): - """A message in a conversation. - - Compatible with Huggingface chat template format. 
- """ - - role: str - content: str - - -class ModelTokenizer(Protocol): - """Protocol for tokenizers that support chat templates. - - This protocol defines the interface that tokenizers must implement - to work with chat-based environments. It's compatible with - Huggingface transformers tokenizers. - """ - - def apply_chat_template( - self, - conversation: list[Message], - tokenize: bool = True, - return_tensors: str | None = None, - **kwargs: Any, - ) -> Any: - """Apply a chat template to format and optionally tokenize a conversation. - - Args: - conversation: List of message dictionaries with 'role' and 'content' - tokenize: Whether to tokenize the output - return_tensors: Format for returned tensors ('pt' for PyTorch) - **kwargs: Additional arguments - - Returns: - Formatted and optionally tokenized conversation - """ - ... - - def decode( - self, token_ids: Any, skip_special_tokens: bool = False, **kwargs: Any - ) -> str: - """Decode token IDs back to text. - - Args: - token_ids: Token IDs to decode - skip_special_tokens: Whether to skip special tokens in output - **kwargs: Additional arguments - - Returns: - Decoded text string - """ - ... - - -class Transform(ABC): - """Transform observations to add rewards, metrics, or other modifications. - - Transforms follow the TorchRL pattern where they take an observation - and return a (potentially modified) observation. This allows for - flexible reward computation and observation augmentation. - """ - - @abstractmethod - def __call__(self, observation: Observation) -> Observation: - """Transform an observation. - - Args: - observation: The input observation - - Returns: - The transformed observation - """ - pass - - -class Environment(ABC): - """Base class for all environment servers following Gym/Gymnasium API. 
- - Args: - transform: Optional transform to apply to observations - """ - - def __init__(self, transform: Transform | None = None): - self.transform = transform - - @abstractmethod - def reset(self) -> Observation: - """Reset the environment and return initial observation.""" - pass - - @abstractmethod - def step(self, action: Action) -> Observation: - """Take a step in the environment.""" - pass - - @property - @abstractmethod - def state(self) -> State: - """Get the current environment state.""" - pass - - def _apply_transform(self, observation: Observation) -> Observation: - """Apply transform if one is provided.""" - if self.transform is not None: - return self.transform(observation) - return observation +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +from typing import Any, Optional, Protocol, TypedDict + +from .types import Action, Observation, State, EnvironmentMetadata + + +class Message(TypedDict): + """A message in a conversation. + + Compatible with Huggingface chat template format. + """ + + role: str + content: str + + +class ModelTokenizer(Protocol): + """Protocol for tokenizers that support chat templates. + + This protocol defines the interface that tokenizers must implement + to work with chat-based environments. It's compatible with + Huggingface transformers tokenizers. + """ + + def apply_chat_template( + self, + conversation: list[Message], + tokenize: bool = True, + return_tensors: str | None = None, + **kwargs: Any, + ) -> Any: + """Apply a chat template to format and optionally tokenize a conversation. 
+ + Args: + conversation: List of message dictionaries with 'role' and 'content' + tokenize: Whether to tokenize the output + return_tensors: Format for returned tensors ('pt' for PyTorch) + **kwargs: Additional arguments + + Returns: + Formatted and optionally tokenized conversation + """ + ... + + def decode( + self, token_ids: Any, skip_special_tokens: bool = False, **kwargs: Any + ) -> str: + """Decode token IDs back to text. + + Args: + token_ids: Token IDs to decode + skip_special_tokens: Whether to skip special tokens in output + **kwargs: Additional arguments + + Returns: + Decoded text string + """ + ... + + +class Transform(ABC): + """Transform observations to add rewards, metrics, or other modifications. + + Transforms follow the TorchRL pattern where they take an observation + and return a (potentially modified) observation. This allows for + flexible reward computation and observation augmentation. + """ + + @abstractmethod + def __call__(self, observation: Observation) -> Observation: + """Transform an observation. + + Args: + observation: The input observation + + Returns: + The transformed observation + """ + pass + + +class Environment(ABC): + """Base class for all environment servers following Gym/Gymnasium API. 
+ + Args: + transform: Optional transform to apply to observations + """ + + def __init__(self, transform: Transform | None = None): + self.transform = transform + + @abstractmethod + def reset( + self, + seed: Optional[int] = None, + episode_id: Optional[str] = None, + **kwargs: Any, + ) -> Observation: + """Reset the environment and return initial observation.""" + pass + + @abstractmethod + def step( + self, + action: Action, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> Observation: + """Take a step in the environment.""" + pass + + @property + @abstractmethod + def state(self) -> State: + """Get the current environment state.""" + pass + + def get_metadata(self) -> EnvironmentMetadata: + """ + Get metadata about this environment. + + Override this method to provide custom metadata for the environment. + Default implementation returns basic metadata derived from class name. + + Returns: + EnvironmentMetadata with environment information + """ + return EnvironmentMetadata( + name=self.__class__.__name__, + description=f"{self.__class__.__name__} environment", + version="1.0.0", + ) + + def _apply_transform(self, observation: Observation) -> Observation: + """Apply transform if one is provided.""" + if self.transform is not None: + return self.transform(observation) + return observation diff --git a/src/core/env_server/route_config.py b/src/core/env_server/route_config.py new file mode 100644 index 00000000..08807c68 --- /dev/null +++ b/src/core/env_server/route_config.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Route configuration utilities for declarative FastAPI route registration. + +This module provides utilities to reduce boilerplate in route registration +by using configuration objects instead of repeated function calls. 
+""" + +from dataclasses import dataclass +from typing import Callable, List, Type + +from fastapi import FastAPI +from pydantic import BaseModel + + +@dataclass +class GetEndpointConfig: + """Configuration for a simple GET endpoint.""" + + path: str + handler: Callable[[], BaseModel | dict] + response_model: Type[BaseModel] | type[dict] + tag: str + summary: str + description: str + + +def register_get_endpoints(app: FastAPI, configs: List[GetEndpointConfig]) -> None: + """ + Register multiple GET endpoints from configuration. + + Args: + app: FastAPI application instance + configs: List of GET endpoint configurations + """ + for config in configs: + # Capture handler in a closure to avoid non-serializable default parameter + def make_endpoint( + handler: Callable[[], BaseModel | dict], + ) -> Callable[[], BaseModel | dict]: + async def endpoint() -> BaseModel | dict: + return handler() + + return endpoint + + app.get( + config.path, + response_model=config.response_model, + tags=[config.tag], + summary=config.summary, + description=config.description, + )(make_endpoint(config.handler)) diff --git a/src/core/env_server/serialization.py b/src/core/env_server/serialization.py new file mode 100644 index 00000000..a97a0528 --- /dev/null +++ b/src/core/env_server/serialization.py @@ -0,0 +1,139 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Shared serialization and deserialization utilities for OpenEnv HTTP servers. + +This module provides common utilities for converting between JSON dictionaries +and Pydantic models (Action/Observation) to eliminate code duplication across +HTTP server and web interface implementations. 
def _parse_token_string(raw: str) -> Any:
    """Best-effort conversion of a textual ``tokens`` value into Python data.

    JSON is tried first (e.g. ``"[1, 2, 3]"``). If the text is not valid JSON
    it is read as comma- and/or whitespace-separated integers (e.g.
    ``"1,2,3"``). Text that fits neither form yields an empty list.
    """
    import json

    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        pass

    try:
        return [int(piece) for piece in raw.replace(",", " ").split()]
    except ValueError:
        # Unparseable input degrades to "no tokens" rather than failing here;
        # the action class still validates the result afterwards.
        return []


def deserialize_action_with_preprocessing(
    action_data: Dict[str, Any], action_cls: Type[Action]
) -> Action:
    """
    Convert JSON dict to Action instance with preprocessing for special types.

    This version handles common type conversions needed for web interfaces:
    - ``tokens`` given as a list, a JSON string, or a comma/whitespace
      separated string of integers is converted to a ``torch.long`` tensor
      (left as a plain list when torch is not installed)
    - ``action_id`` given as a string is converted to int when it parses;
      otherwise the original string is kept
    - every other key is passed through unchanged

    ``action_data`` itself is not modified.

    Args:
        action_data: Dictionary containing action data
        action_cls: The Action subclass to instantiate

    Returns:
        Action instance

    Raises:
        ValidationError: If action_data is invalid for the action class
    """
    processed_data: Dict[str, Any] = {}

    for key, value in action_data.items():
        if key == "tokens" and isinstance(value, (list, str)):
            if isinstance(value, str):
                value = _parse_token_string(value)
            if isinstance(value, list):
                try:
                    import torch
                except ImportError:
                    # If torch not available, keep as list
                    processed_data[key] = value
                else:
                    processed_data[key] = torch.tensor(value, dtype=torch.long)
            else:
                # The string parsed to something other than a list (e.g. a
                # bare number); hand it to validation as-is.
                processed_data[key] = value
        elif key == "action_id" and isinstance(value, str):
            try:
                processed_data[key] = int(value)
            except ValueError:
                # If conversion fails, keep original value
                processed_data[key] = value
        else:
            processed_data[key] = value

    return action_cls.model_validate(processed_data)
+ + All action subclasses should inherit from this base class. + Uses Pydantic for automatic validation and serialization. + """ + + model_config = ConfigDict( + extra="forbid", # Reject unknown fields + validate_assignment=True, # Validate on field assignment + arbitrary_types_allowed=True, # Allow numpy arrays, torch tensors, etc. + ) + + metadata: Dict[str, Any] = Field( + default_factory=dict, description="Additional metadata for the action" + ) + + +class Observation(BaseModel): + """Base class for all environment observations. - metadata: Dict[str, Any] = field(default_factory=dict) - - -@dataclass(kw_only=True) -class Observation: - """Base class for all environment observations.""" - - done: bool = False - reward: Union[bool, int, float, None] = None - metadata: Dict[str, Any] = field(default_factory=dict) - - -@dataclass -class State: - """Base class for environment state.""" - - episode_id: Optional[str] = None - step_count: int = 0 - - -@dataclass -class CodeExecResult: + All observation subclasses should inherit from this base class. + Uses Pydantic for automatic validation and serialization. 
class ResetRequest(BaseModel):
    """Request model for environment reset.

    Both declared fields are optional, so an empty JSON object is a valid
    request (it appears among the schema examples below). Undeclared keys are
    accepted as well because of ``extra="allow"``.
    """

    model_config = ConfigDict(
        extra="allow",  # Allow extra fields for custom reset parameters
        # Example payloads attached to the generated JSON schema.
        json_schema_extra={"examples": [{"seed": 42, "episode_id": "episode-001"}, {}]},
    )

    # ge=0: negative seeds fail validation.
    seed: Optional[int] = Field(
        default=None, ge=0, description="Random seed for reproducible episodes"
    )
    # max_length=255: longer identifiers fail validation.
    episode_id: Optional[str] = Field(
        default=None, max_length=255, description="Custom episode identifier"
    )
class State(BaseModel):
    """Base class for environment state.

    Represents internal environment state, separate from observations.

    The model accepts undeclared fields, re-validates values on attribute
    assignment, and permits field types Pydantic has no built-in schema for.
    """

    model_config = ConfigDict(
        extra="allow",  # Allow extra fields for flexibility
        validate_assignment=True,  # Re-run validation when a field is assigned
        arbitrary_types_allowed=True,  # Permit non-Pydantic field types
    )

    episode_id: Optional[str] = Field(
        default=None, description="Unique identifier for the current episode"
    )
    step_count: int = Field(
        default=0,
        ge=0,  # Greater than or equal to 0
        description="Number of steps taken in the current episode",
    )
class SchemaResponse(BaseModel):
    """Response model for the combined schema endpoint.

    Carries three JSON-schema dictionaries, one each for actions,
    observations and state. All three fields are required and unknown keys
    are rejected.
    """

    model_config = ConfigDict(
        extra="forbid",  # Reject keys other than the three declared below
        validate_assignment=True,  # Re-run validation when a field is assigned
    )

    action: Dict[str, Any] = Field(
        description="JSON schema for actions accepted by this environment"
    )
    observation: Dict[str, Any] = Field(
        description="JSON schema for observations returned by this environment"
    )
    state: Dict[str, Any] = Field(
        description="JSON schema for environment state objects"
    )
-""" - -from __future__ import annotations - -import json -import time -from dataclasses import asdict, dataclass -from typing import Any, Dict, List, Optional, Type -from datetime import datetime - -from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request -from fastapi.responses import HTMLResponse, FileResponse -from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel - -from .interfaces import Environment -from .types import Action, Observation, State, EnvironmentMetadata - - -def load_environment_metadata(env: Environment, env_name: Optional[str] = None) -> EnvironmentMetadata: - """ - Load environment metadata including README content. - - Args: - env: The environment instance - env_name: Optional environment name for README file lookup - - Returns: - EnvironmentMetadata with loaded information - """ - # Try to get metadata from environment if it has a method for it - if hasattr(env, 'get_metadata'): - return env.get_metadata() - - # Default metadata - metadata = EnvironmentMetadata( - name=env_name or env.__class__.__name__, - description=f"{env.__class__.__name__} environment", - version="1.0.0" - ) - - # Try to load README from file system - readme_content = _load_readme_from_filesystem(env_name) - if readme_content: - metadata.readme_content = readme_content - - return metadata - - -def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]: - """ - Load README content from the filesystem. - - Tries multiple locations: - 1. Container filesystem: /app/README.md - 2. Local development: src/envs/{env_name}/README.md - 3. 
Environment variable: ENV_README_PATH - """ - import os - from pathlib import Path - - # Try container filesystem first - container_readme = Path("/app/README.md") - if container_readme.exists(): - try: - return container_readme.read_text(encoding='utf-8') - except Exception: - pass - - # Try environment variable path - custom_path = os.environ.get("ENV_README_PATH") - if custom_path and Path(custom_path).exists(): - try: - return Path(custom_path).read_text(encoding='utf-8') - except Exception: - pass - - # Try local development path - if env_name: - local_readme = Path(f"src/envs/{env_name}/README.md") - if local_readme.exists(): - try: - return local_readme.read_text(encoding='utf-8') - except Exception: - pass - - return None - - -@dataclass -class ActionLog: - """Log entry for an action taken.""" - timestamp: str - action: Dict[str, Any] - observation: Dict[str, Any] - reward: Optional[float] - done: bool - step_count: int - - -@dataclass -class EpisodeState: - """Current episode state for the web interface.""" - episode_id: Optional[str] - step_count: int - current_observation: Optional[Dict[str, Any]] - action_logs: List[ActionLog] - is_reset: bool = True - - -class WebInterfaceManager: - """Manages the web interface for an environment.""" - - def __init__( - self, - env: Environment, - action_cls: Type[Action], - observation_cls: Type[Observation], - metadata: Optional[EnvironmentMetadata] = None, - ): - self.env = env - self.action_cls = action_cls - self.observation_cls = observation_cls - self.metadata = metadata or EnvironmentMetadata( - name=env.__class__.__name__, - description=f"{env.__class__.__name__} environment" - ) - self.episode_state = EpisodeState( - episode_id=None, - step_count=0, - current_observation=None, - action_logs=[] - ) - self.connected_clients: List[WebSocket] = [] - - async def connect_websocket(self, websocket: WebSocket): - """Connect a new WebSocket client.""" - await websocket.accept() - 
self.connected_clients.append(websocket) - - # Send current state to the new client - await self._send_state_update() - - async def disconnect_websocket(self, websocket: WebSocket): - """Disconnect a WebSocket client.""" - if websocket in self.connected_clients: - self.connected_clients.remove(websocket) - - async def _send_state_update(self): - """Send current state to all connected clients.""" - if not self.connected_clients: - return - - state_data = { - "type": "state_update", - "episode_state": asdict(self.episode_state) - } - - # Send to all connected clients - disconnected_clients = [] - for client in self.connected_clients: - try: - await client.send_text(json.dumps(state_data)) - except: - disconnected_clients.append(client) - - # Remove disconnected clients - for client in disconnected_clients: - self.connected_clients.remove(client) - - async def reset_environment(self) -> Dict[str, Any]: - """Reset the environment and update state.""" - observation = self.env.reset() - state = self.env.state - - # Update episode state - self.episode_state.episode_id = state.episode_id - self.episode_state.step_count = 0 - self.episode_state.current_observation = asdict(observation) - self.episode_state.action_logs = [] - self.episode_state.is_reset = True - - # Send state update - await self._send_state_update() - - return { - "observation": asdict(observation), - "reward": observation.reward, - "done": observation.done, - } - - async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]: - """Execute a step in the environment and update state.""" - # Deserialize action - action = self._deserialize_action(action_data) - - # Execute step - observation = self.env.step(action) - state = self.env.state - - # Create action log - action_log = ActionLog( - timestamp=datetime.now().isoformat(), - action=asdict(action), - observation=asdict(observation), - reward=observation.reward, - done=observation.done, - step_count=state.step_count - ) - - # Update 
episode state - self.episode_state.episode_id = state.episode_id - self.episode_state.step_count = state.step_count - self.episode_state.current_observation = asdict(observation) - self.episode_state.action_logs.append(action_log) - self.episode_state.is_reset = False - - # Send state update - await self._send_state_update() - - return { - "observation": asdict(observation), - "reward": observation.reward, - "done": observation.done, - } - - def get_state(self) -> Dict[str, Any]: - """Get current environment state.""" - state = self.env.state - return asdict(state) - - def _deserialize_action(self, action_data: Dict[str, Any]) -> Action: - """Convert JSON dict to Action instance.""" - metadata = action_data.pop("metadata", {}) - - # Handle tensor fields that come from JSON as lists - processed_data = {} - for key, value in action_data.items(): - if key == "tokens" and isinstance(value, (list, str)): - # Convert list or string to tensor - if isinstance(value, str): - # If it's a string, try to parse it as a list of numbers - try: - import json - value = json.loads(value) - except: - # If parsing fails, treat as empty list - value = [] - if isinstance(value, list): - import torch - processed_data[key] = torch.tensor(value, dtype=torch.long) - else: - processed_data[key] = value - elif key == "action_id" and isinstance(value, str): - # Convert action_id from string to int - try: - processed_data[key] = int(value) - except ValueError: - # If conversion fails, keep original value - processed_data[key] = value - else: - processed_data[key] = value - - action = self.action_cls(**processed_data) - action.metadata = metadata - return action - - -def create_web_interface_app( - env: Environment, - action_cls: Type[Action], - observation_cls: Type[Observation], - env_name: Optional[str] = None, -) -> FastAPI: - """ - Create a FastAPI application with web interface for the given environment. 
- - Args: - env: The Environment instance to serve - action_cls: The Action subclass this environment expects - observation_cls: The Observation subclass this environment returns - env_name: Optional environment name for README loading - - Returns: - FastAPI application instance with web interface - """ - from .http_server import create_fastapi_app - - # Create the base environment app - app = create_fastapi_app(env, action_cls, observation_cls) - - # Load environment metadata - metadata = load_environment_metadata(env, env_name) - - # Create web interface manager - web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata) - - # Add web interface routes - @app.get("/web", response_class=HTMLResponse) - async def web_interface(): - """Serve the web interface.""" - return get_web_interface_html(action_cls, web_manager.metadata) - - @app.get("/web/metadata") - async def web_metadata(): - """Get environment metadata.""" - return asdict(web_manager.metadata) - - @app.websocket("/ws") - async def websocket_endpoint(websocket: WebSocket): - """WebSocket endpoint for real-time updates.""" - await web_manager.connect_websocket(websocket) - try: - while True: - # Keep connection alive - await websocket.receive_text() - except WebSocketDisconnect: - await web_manager.disconnect_websocket(websocket) - - @app.post("/web/reset") - async def web_reset(): - """Reset endpoint for web interface.""" - return await web_manager.reset_environment() - - @app.post("/web/step") - async def web_step(request: Dict[str, Any]): - """Step endpoint for web interface.""" - # Check if this is a message-based request (chat environment) - if "message" in request: - message = request["message"] - # Convert message to action using the environment's message_to_action method - action = web_manager.env.message_to_action(message) - action_data = {"tokens": action.tokens.tolist()} - else: - action_data = request.get("action", {}) - - return await 
web_manager.step_environment(action_data) - - @app.get("/web/state") - async def web_state(): - """State endpoint for web interface.""" - return web_manager.get_state() - - return app - - -def get_web_interface_html(action_cls: Type[Action], metadata: Optional[EnvironmentMetadata] = None) -> str: - """Generate the HTML for the web interface.""" - - # Check if this is a chat environment by looking for tokens field - is_chat_env = False - if hasattr(action_cls, '__dataclass_fields__'): - for field_name, field_info in action_cls.__dataclass_fields__.items(): - if field_name == 'tokens' and hasattr(field_info.type, '__name__') and 'Tensor' in field_info.type.__name__: - is_chat_env = True - break - - # Get action fields for dynamic form generation with enhanced metadata - action_fields = _extract_action_fields(action_cls) - - return f""" - - - - - - OpenEnv Web Interface - - - -
- -
-
- - HumanAgent Interface -
-
- - {_generate_instructions_section(metadata)} - - - {_generate_action_interface(action_fields, is_chat_env)} - - -
- - -
- - -
-

Current State

-
-
- Status: - Not initialized -
-
- Episode ID: - - -
-
- Step Count: - 0 -
-
-
-
-
- - -
-
- State Observer -
-
- -
-

Current Observation

-
- No observation yet -
-
- - -
-

Action History

-
- No actions taken yet -
-
-
-
-
- - - - - """.replace('{_generate_action_form_fields(action_fields)}', _generate_action_form_fields(action_fields)) - - -def _generate_instructions_section(metadata: Optional[EnvironmentMetadata]) -> str: - """Generate the instructions section with environment documentation.""" - if not metadata or not metadata.readme_content: - return '' - - # Convert markdown to HTML (basic conversion) - import re - html_content = _markdown_to_html(metadata.readme_content) - - return f''' - -
-
-

{metadata.name}

- -
-
-
- {html_content} -
-
-
- ''' - - -def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]: - """Extract enhanced field metadata from Action class for form generation.""" - import typing - from typing import get_origin, get_args - - action_fields = [] - if not hasattr(action_cls, '__dataclass_fields__'): - return action_fields - - for field_name, field_info in action_cls.__dataclass_fields__.items(): - if field_name == 'metadata': - continue - - field_type = field_info.type - field_metadata = _extract_field_metadata(field_name, field_info) - - # Determine input type based on field type - input_type = _determine_input_type(field_type) - - # Check if field is required - is_required = field_info.default is field_info.default_factory - - action_fields.append({ - 'name': field_name, - 'type': input_type, - 'required': is_required, - 'description': field_metadata.get('description', ''), - 'default_value': field_metadata.get('default_value'), - 'choices': field_metadata.get('choices', []), - 'min_value': field_metadata.get('min_value'), - 'max_value': field_metadata.get('max_value'), - 'placeholder': field_metadata.get('placeholder', ''), - 'help_text': field_metadata.get('help_text', ''), - }) - - return action_fields - - -def _extract_field_metadata(field_name: str, field_info) -> Dict[str, Any]: - """Extract metadata from dataclass field including docstring and type hints.""" - import typing - from typing import get_origin, get_args, Literal, Union, Optional - - metadata = {} - - # Extract description from field docstring or annotation - if hasattr(field_info, 'metadata') and field_info.metadata: - # Check for custom metadata - for meta in field_info.metadata: - if isinstance(meta, dict): - metadata.update(meta) - - # Extract type information - field_type = field_info.type - origin = get_origin(field_type) - - # Handle Literal types for dropdown choices - if origin is Literal: - args = get_args(field_type) - metadata['choices'] = list(args) - - # Handle Optional types - 
if origin is Union: - args = get_args(field_type) - if len(args) == 2 and type(None) in args: - # This is Optional[SomeType] - non_none_type = args[0] if args[1] is type(None) else args[1] - metadata['optional'] = True - # Recursively check the non-None type for choices - if get_origin(non_none_type) is Literal: - metadata['choices'] = list(get_args(non_none_type)) - else: - # Regular Union type - metadata['choices'] = [str(arg) for arg in args if arg is not type(None)] - - # Handle numeric constraints - if field_type in (int, float): - # Check for common constraint patterns in field name - if 'count' in field_name.lower() or 'num' in field_name.lower(): - metadata['min_value'] = 0 - if 'id' in field_name.lower(): - metadata['min_value'] = 0 - - # Generate placeholder text - if 'message' in field_name.lower(): - metadata['placeholder'] = f'Enter {field_name.replace("_", " ")}...' - elif 'code' in field_name.lower(): - metadata['placeholder'] = 'Enter Python code here...' - elif 'tokens' in field_name.lower(): - metadata['placeholder'] = 'Enter comma-separated token IDs (e.g., 1,2,3,4,5)' - else: - metadata['placeholder'] = f'Enter {field_name.replace("_", " ")}...' 
- - # Generate help text based on field name and type - if 'action_id' in field_name.lower(): - metadata['help_text'] = 'The action ID to execute in the environment' - elif 'game_name' in field_name.lower(): - metadata['help_text'] = 'Name of the game or environment' - elif 'tokens' in field_name.lower(): - metadata['help_text'] = 'Token IDs as a comma-separated list of integers' - elif 'code' in field_name.lower(): - metadata['help_text'] = 'Python code to execute in the environment' - elif 'message' in field_name.lower(): - metadata['help_text'] = 'Text message to send' - - return metadata - - -def _determine_input_type(field_type) -> str: - """Determine the appropriate HTML input type for a field type.""" - import typing - from typing import get_origin, get_args, Literal, Union - - # Handle direct types - if field_type == str: - return "text" - elif field_type == int: - return "number" - elif field_type == float: - return "number" - elif field_type == bool: - return "checkbox" - - # Handle complex types - origin = get_origin(field_type) - - if origin is Literal: - return "select" - elif origin is Union: - args = get_args(field_type) - if len(args) == 2 and type(None) in args: - # Optional type - use the non-None type - non_none_type = args[0] if args[1] is type(None) else args[1] - return _determine_input_type(non_none_type) - elif all(isinstance(arg, str) for arg in args if arg is not type(None)): - return "select" - else: - return "text" - elif hasattr(field_type, '__name__') and 'Tensor' in field_type.__name__: - return "tensor" - else: - return "text" - - -def _markdown_to_html(markdown: str) -> str: - """Convert basic markdown to HTML for README display.""" - import html - import re - - # Escape HTML first - html_content = html.escape(markdown) - - # Convert headers - html_content = re.sub(r'^# (.*?)$', r'

\1

', html_content, flags=re.MULTILINE) - html_content = re.sub(r'^## (.*?)$', r'

\1

', html_content, flags=re.MULTILINE) - html_content = re.sub(r'^### (.*?)$', r'

\1

', html_content, flags=re.MULTILINE) - - # Convert code blocks - html_content = re.sub(r'```(.*?)\n(.*?)\n```', r'
\2
', html_content, flags=re.DOTALL) - html_content = re.sub(r'`([^`]+)`', r'\1', html_content) - - # Convert bold and italic - html_content = re.sub(r'\*\*(.*?)\*\*', r'\1', html_content) - html_content = re.sub(r'\*(.*?)\*', r'\1', html_content) - - # Convert lists - html_content = re.sub(r'^- (.*?)$', r'
  • \1
  • ', html_content, flags=re.MULTILINE) - html_content = re.sub(r'(
  • .*
  • )', r'', html_content, flags=re.DOTALL) - - # Convert line breaks - html_content = html_content.replace('\n', '
    ') - - return html_content - - -def _generate_action_interface(action_fields: List[Dict[str, Any]], is_chat_env: bool) -> str: - """Generate either a chat interface or action form based on environment type.""" - if is_chat_env: - return _generate_chat_interface() - else: - return _generate_action_form(action_fields) - -def _generate_chat_interface() -> str: - """Generate a chat-style interface for chat environments.""" - return ''' - -
    -

    Chat Interface

    -
    -
    -
    System
    -
    Chat environment ready. Send a message to start the conversation.
    -
    -
    -
    -
    - - -
    -
    - - -
    -
    -
    - ''' - -def _generate_action_form(action_fields: List[Dict[str, Any]]) -> str: - """Generate a traditional action form for non-chat environments.""" - return f''' - -
    -

    Take Action

    -
    - {_generate_action_form_fields(action_fields)} - -
    -
    - ''' - -def _generate_action_form_fields(action_fields: List[Dict[str, Any]]) -> str: - """Generate HTML form fields for action input with enhanced metadata.""" - if not action_fields: - return '

    No action fields available

    ' - - fields_html = [] - for field in action_fields: - field_html = _generate_single_field(field) - fields_html.append(field_html) - - return '\n'.join(fields_html) - - -def _generate_single_field(field: Dict[str, Any]) -> str: - """Generate HTML for a single form field with enhanced metadata.""" - field_name = field['name'] - field_type = field['type'] - required = field['required'] - placeholder = field.get('placeholder', '') - help_text = field.get('help_text', '') - choices = field.get('choices', []) - min_value = field.get('min_value') - max_value = field.get('max_value') - default_value = field.get('default_value') - - # Build label with required indicator - label_text = field_name.replace('_', ' ').title() - if required: - label_text += ' *' - - # Build input attributes - input_attrs = [] - if required: - input_attrs.append('required') - if placeholder: - input_attrs.append(f'placeholder="{placeholder}"') - if min_value is not None: - input_attrs.append(f'min="{min_value}"') - if max_value is not None: - input_attrs.append(f'max="{max_value}"') - if default_value is not None: - input_attrs.append(f'value="{default_value}"') - - attrs_str = ' '.join(input_attrs) - - if field_type == 'checkbox': - return f''' -
    - - {f'{help_text}' if help_text else ''} -
    - ''' - - elif field_type == 'select': - options_html = [] - if not required: - options_html.append(f'') - - for choice in choices: - selected = 'selected' if str(choice) == str(default_value) else '' - options_html.append(f'') - - return f''' -
    - - - {f'{help_text}' if help_text else ''} -
    - ''' - - elif field_type == 'tensor': - return f''' -
    - - - {help_text or 'Enter token IDs as comma-separated integers (e.g., 1,2,3,4,5)'} -
    - ''' - - elif field_type == 'text' and ('message' in field_name.lower() or 'code' in field_name.lower()): - return f''' -
    - - - {f'{help_text}' if help_text else ''} -
    - ''' - - else: - return f''' -
    - - - {f'{help_text}' if help_text else ''} -
    - ''' +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Web interface for OpenEnv environments. + +This module provides a web-based interface for interacting with OpenEnv environments, +including a two-pane layout for HumanAgent interaction and state observation. +""" + +from __future__ import annotations + +import json +from typing import Any, Dict, List, Optional, Type +from datetime import datetime + +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi.responses import HTMLResponse +from pydantic import BaseModel, Field, ConfigDict + +from .interfaces import Environment +from .serialization import deserialize_action_with_preprocessing, serialize_observation +from .types import Action, Observation, State, EnvironmentMetadata + + +def load_environment_metadata( + env: Environment, env_name: Optional[str] = None +) -> EnvironmentMetadata: + """ + Load environment metadata including README content. + + Args: + env: The environment instance + env_name: Optional environment name for README file lookup + + Returns: + EnvironmentMetadata with loaded information + """ + # Try to get metadata from environment if it has a method for it + if hasattr(env, "get_metadata"): + return env.get_metadata() + + # Default metadata + metadata = EnvironmentMetadata( + name=env_name or env.__class__.__name__, + description=f"{env.__class__.__name__} environment", + version="1.0.0", + ) + + # Try to load README from file system + readme_content = _load_readme_from_filesystem(env_name) + if readme_content: + metadata.readme_content = readme_content + + return metadata + + +def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]: + """ + Load README content from the filesystem. + + Tries multiple locations: + 1. Container filesystem: /app/README.md + 2. 
Local development: src/envs/{env_name}/README.md + 3. Environment variable: ENV_README_PATH + """ + import os + from pathlib import Path + + # Try container filesystem first + container_readme = Path("/app/README.md") + if container_readme.exists(): + try: + return container_readme.read_text(encoding="utf-8") + except Exception: + pass + + # Try environment variable path + custom_path = os.environ.get("ENV_README_PATH") + if custom_path and Path(custom_path).exists(): + try: + return Path(custom_path).read_text(encoding="utf-8") + except Exception: + pass + + # Try local development path + if env_name: + local_readme = Path(f"src/envs/{env_name}/README.md") + if local_readme.exists(): + try: + return local_readme.read_text(encoding="utf-8") + except Exception: + pass + + return None + + +class ActionLog(BaseModel): + """Log entry for an action taken.""" + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + timestamp: str = Field(description="Timestamp when action was taken") + action: Dict[str, Any] = Field(description="Action that was taken") + observation: Dict[str, Any] = Field(description="Observation returned from action") + reward: Optional[float] = Field( + default=None, description="Reward received from action" + ) + done: bool = Field(description="Whether the episode is done after this action") + step_count: int = Field(description="Step count when this action was taken") + + +class EpisodeState(BaseModel): + """Current episode state for the web interface.""" + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + episode_id: Optional[str] = Field(default=None, description="Current episode ID") + step_count: int = Field(description="Current step count in episode") + current_observation: Optional[Dict[str, Any]] = Field( + default=None, description="Current observation" + ) + action_logs: List[ActionLog] = Field( + default_factory=list, description="List of action logs" + ) + is_reset: bool = Field( + default=True, 
description="Whether the episode has been reset" + ) + + +class WebInterfaceManager: + """Manages the web interface for an environment.""" + + def __init__( + self, + env: Environment, + action_cls: Type[Action], + observation_cls: Type[Observation], + metadata: Optional[EnvironmentMetadata] = None, + ): + self.env = env + self.action_cls = action_cls + self.observation_cls = observation_cls + self.metadata = metadata or EnvironmentMetadata( + name=env.__class__.__name__, + description=f"{env.__class__.__name__} environment", + ) + self.episode_state = EpisodeState( + episode_id=None, step_count=0, current_observation=None, action_logs=[] + ) + self.connected_clients: List[WebSocket] = [] + + async def connect_websocket(self, websocket: WebSocket): + """Connect a new WebSocket client.""" + await websocket.accept() + self.connected_clients.append(websocket) + + # Send current state to the new client + await self._send_state_update() + + async def disconnect_websocket(self, websocket: WebSocket): + """Disconnect a WebSocket client.""" + if websocket in self.connected_clients: + self.connected_clients.remove(websocket) + + async def _send_state_update(self): + """Send current state to all connected clients.""" + if not self.connected_clients: + return + + state_data = { + "type": "state_update", + "episode_state": self.episode_state.model_dump(), + } + + # Send to all connected clients + disconnected_clients = [] + for client in self.connected_clients: + try: + await client.send_text(json.dumps(state_data)) + except Exception: + disconnected_clients.append(client) + + # Remove disconnected clients + for client in disconnected_clients: + self.connected_clients.remove(client) + + async def reset_environment(self) -> Dict[str, Any]: + """Reset the environment and update state.""" + observation: Observation = self.env.reset() + state: State = self.env.state + + # Serialize observation once using shared utility + serialized = serialize_observation(observation) + + # Update 
episode state + self.episode_state.episode_id = state.episode_id + self.episode_state.step_count = 0 + self.episode_state.current_observation = serialized["observation"] + self.episode_state.action_logs = [] + self.episode_state.is_reset = True + + # Send state update + await self._send_state_update() + + return serialized + + async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]: + """Execute a step in the environment and update state.""" + # Deserialize action with preprocessing for web interface special cases + action: Action = deserialize_action_with_preprocessing( + action_data, self.action_cls + ) + + # Execute step + observation: Observation = self.env.step(action) + state: State = self.env.state + + # Serialize observation once using shared utility + serialized = serialize_observation(observation) + + # Create action log + action_log = ActionLog( + timestamp=datetime.now().isoformat(), + action=action.model_dump(exclude={"metadata"}), + observation=serialized["observation"], + reward=observation.reward, + done=observation.done, + step_count=state.step_count, + ) + + # Update episode state + self.episode_state.episode_id = state.episode_id + self.episode_state.step_count = state.step_count + self.episode_state.current_observation = serialized["observation"] + self.episode_state.action_logs.append(action_log) + self.episode_state.is_reset = False + + # Send state update + await self._send_state_update() + + return serialized + + def get_state(self) -> Dict[str, Any]: + """Get current environment state.""" + state: State = self.env.state + return state.model_dump() + + +def create_web_interface_app( + env: Environment, + action_cls: Type[Action], + observation_cls: Type[Observation], + env_name: Optional[str] = None, +) -> FastAPI: + """ + Create a FastAPI application with web interface for the given environment. 
+ + Args: + env: The Environment instance to serve + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + env_name: Optional environment name for README loading + + Returns: + FastAPI application instance with web interface + """ + from .http_server import create_fastapi_app + + # Create the base environment app + app = create_fastapi_app(env, action_cls, observation_cls) + + # Load environment metadata + metadata = load_environment_metadata(env, env_name) + + # Create web interface manager + web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata) + + # Add web interface routes + @app.get("/web", response_class=HTMLResponse) + async def web_interface(): + """Serve the web interface.""" + return get_web_interface_html(action_cls, web_manager.metadata) + + @app.get("/web/metadata") + async def web_metadata(): + """Get environment metadata.""" + return web_manager.metadata.model_dump() + + @app.websocket("/ws") + async def websocket_endpoint(websocket: WebSocket): + """WebSocket endpoint for real-time updates.""" + await web_manager.connect_websocket(websocket) + try: + while True: + # Keep connection alive + await websocket.receive_text() + except WebSocketDisconnect: + await web_manager.disconnect_websocket(websocket) + + @app.post("/web/reset") + async def web_reset(): + """Reset endpoint for web interface.""" + return await web_manager.reset_environment() + + @app.post("/web/step") + async def web_step(request: Dict[str, Any]): + """Step endpoint for web interface.""" + # Check if this is a message-based request (chat environment) + if "message" in request: + message = request["message"] + # Convert message to action using the environment's message_to_action method + action = web_manager.env.message_to_action(message) + action_data = {"tokens": action.tokens.tolist()} + else: + action_data = request.get("action", {}) + + return await 
web_manager.step_environment(action_data) + + @app.get("/web/state") + async def web_state(): + """State endpoint for web interface.""" + return web_manager.get_state() + + return app + + +def get_web_interface_html( + action_cls: Type[Action], metadata: Optional[EnvironmentMetadata] = None +) -> str: + """Generate the HTML for the web interface.""" + + # Check if this is a chat environment by looking for tokens field + is_chat_env = False + if hasattr(action_cls, "model_fields"): + for field_name, field_info in action_cls.model_fields.items(): + if ( + field_name == "tokens" + and hasattr(field_info.annotation, "__name__") + and "Tensor" in field_info.annotation.__name__ + ): + is_chat_env = True + break + + # Get action fields for dynamic form generation with enhanced metadata + action_fields = _extract_action_fields(action_cls) + + return f""" + + + + + + OpenEnv Web Interface + + + +
    + +
    +
    + + HumanAgent Interface +
    +
    + + {_generate_instructions_section(metadata)} + + + {_generate_action_interface(action_fields, is_chat_env)} + + +
    + + +
    + + +
    +

    Current State

    +
    +
    + Status: + Not initialized +
    +
    + Episode ID: + - +
    +
    + Step Count: + 0 +
    +
    +
    +
    +
    + + +
    +
    + State Observer +
    +
    + +
    +

    Current Observation

    +
    + No observation yet +
    +
    + + +
    +

    Action History

    +
    + No actions taken yet +
    +
    +
    +
    +
    + + + + + """.replace( + "{_generate_action_form_fields(action_fields)}", + _generate_action_form_fields(action_fields), + ) + + +def _generate_instructions_section(metadata: Optional[EnvironmentMetadata]) -> str: + """Generate the instructions section with environment documentation.""" + if not metadata or not metadata.readme_content: + return "" + + html_content = _markdown_to_html(metadata.readme_content) + + return f""" + +
    +
    +

    {metadata.name}

    + +
    +
    +
    + {html_content} +
    +
    +
    + """ + + +def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]: + """Extract enhanced field metadata from Action class for form generation.""" + # Use Pydantic's JSON schema generation for robust metadata extraction + try: + schema = action_cls.model_json_schema() + except AttributeError: + # Fallback for non-Pydantic v2 models or if something goes wrong + return [] + + properties = schema.get("properties", {}) + required_fields = schema.get("required", []) + + action_fields = [] + + for field_name, field_info in properties.items(): + if field_name == "metadata": + continue + + # JSON schema "type" can be a string or list/undefined + # Determine our internal input type + input_type = _determine_input_type_from_schema(field_info, field_name) + + is_required = field_name in required_fields + + action_fields.append( + { + "name": field_name, + "type": input_type, + "required": is_required, + "description": field_info.get("description", ""), + "default_value": field_info.get("default"), + "choices": field_info.get("enum"), + "min_value": field_info.get("minimum"), + "max_value": field_info.get("maximum"), + "min_length": field_info.get("minLength"), + "max_length": field_info.get("maxLength"), + "pattern": field_info.get("pattern"), + "placeholder": _generate_placeholder(field_name, field_info), + "help_text": _generate_help_text(field_name, field_info), + } + ) + + return action_fields + + +def _determine_input_type_from_schema( + field_info: Dict[str, Any], field_name: str +) -> str: + """Determine the appropriate HTML input type from JSON schema info.""" + schema_type = field_info.get("type") + + # Check for specific tensor field convention + if "tokens" in field_name.lower(): + return "tensor" + + if "enum" in field_info: + return "select" + + if schema_type == "boolean": + return "checkbox" + + if schema_type == "integer" or schema_type == "number": + return "number" + + if schema_type == "string": + # Check if it should be a textarea 
+ if ( + field_info.get("maxLength", 0) > 100 + or "message" in field_name.lower() + or "code" in field_name.lower() + ): + return "textarea" + return "text" + + # Default fallback + return "text" + + +def _generate_placeholder(field_name: str, field_info: Dict[str, Any]) -> str: + """Generate placeholder text.""" + if "message" in field_name.lower(): + return f"Enter {field_name.replace('_', ' ')}..." + elif "code" in field_name.lower(): + return "Enter Python code here..." + elif "tokens" in field_name.lower(): + return "Enter comma-separated token IDs (e.g., 1,2,3,4,5)" + else: + return f"Enter {field_name.replace('_', ' ')}..." + + +def _generate_help_text(field_name: str, field_info: Dict[str, Any]) -> str: + """Generate help text.""" + description = field_info.get("description", "") + if description: + return description + + if "action_id" in field_name.lower(): + return "The action ID to execute in environment" + elif "game_name" in field_name.lower(): + return "Name of game or environment" + elif "tokens" in field_name.lower(): + return "Token IDs as a comma-separated list of integers" + elif "code" in field_name.lower(): + return "Python code to execute in environment" + elif "message" in field_name.lower(): + return "Text message to send" + + return "" + + +def _markdown_to_html(markdown: str) -> str: + """Convert basic markdown to HTML for README display.""" + import html + import re + + # Escape HTML first + html_content = html.escape(markdown) + + # Convert headers + html_content = re.sub( + r"^# (.*?)$", r"

    \1

    ", html_content, flags=re.MULTILINE + ) + html_content = re.sub( + r"^## (.*?)$", r"

    \1

    ", html_content, flags=re.MULTILINE + ) + html_content = re.sub( + r"^### (.*?)$", r"

    \1

    ", html_content, flags=re.MULTILINE + ) + + # Convert code blocks + html_content = re.sub( + r"```(.*?)\n(.*?)\n```", + r"
    \2
    ", + html_content, + flags=re.DOTALL, + ) + html_content = re.sub(r"`([^`]+)`", r"\1", html_content) + + # Convert bold and italic + html_content = re.sub(r"\*\*(.*?)\*\*", r"\1", html_content) + html_content = re.sub(r"\*(.*?)\*", r"\1", html_content) + + # Convert lists + html_content = re.sub( + r"^- (.*?)$", r"
  • \1
  • ", html_content, flags=re.MULTILINE + ) + html_content = re.sub( + r"(
  • .*
  • )", r"", html_content, flags=re.DOTALL + ) + + # Convert line breaks + html_content = html_content.replace("\n", "
    ") + + return html_content + + +def _generate_action_interface( + action_fields: List[Dict[str, Any]], is_chat_env: bool +) -> str: + """Generate either a chat interface or action form based on environment type.""" + if is_chat_env: + return _generate_chat_interface() + else: + return _generate_action_form(action_fields) + + +def _generate_chat_interface() -> str: + """Generate a chat-style interface for chat environments.""" + return """ + +
    +

    Chat Interface

    +
    +
    +
    System
    +
    Chat environment ready. Send a message to start the conversation.
    +
    +
    +
    +
    + + +
    +
    + + +
    +
    +
    + """ + + +def _generate_action_form(action_fields: List[Dict[str, Any]]) -> str: + """Generate a traditional action form for non-chat environments.""" + return f""" + +
    +

    Take Action

    +
    + {_generate_action_form_fields(action_fields)} + +
    +
    + """ + + +def _generate_action_form_fields(action_fields: List[Dict[str, Any]]) -> str: + """Generate HTML form fields for action input with enhanced metadata.""" + if not action_fields: + return "

    No action fields available

    " + + fields_html = [] + for field in action_fields: + field_html = _generate_single_field(field) + fields_html.append(field_html) + + return "\n".join(fields_html) + + +def _generate_single_field(field: Dict[str, Any]) -> str: + """Generate HTML for a single form field with enhanced metadata.""" + field_name = field["name"] + field_type = field["type"] + required = field["required"] + placeholder = field.get("placeholder", "") + help_text = field.get("help_text", "") + choices = field.get("choices", []) + min_value = field.get("min_value") + max_value = field.get("max_value") + default_value = field.get("default_value") + min_length = field.get("min_length") + max_length = field.get("max_length") + pattern = field.get("pattern") + + # Build label with required indicator + label_text = field_name.replace("_", " ").title() + if required: + label_text += ' *' + + # Build input attributes + input_attrs = [] + if required: + input_attrs.append("required") + if placeholder: + input_attrs.append(f'placeholder="{placeholder}"') + if min_value is not None: + input_attrs.append(f'min="{min_value}"') + if max_value is not None: + input_attrs.append(f'max="{max_value}"') + if min_length is not None: + input_attrs.append(f'minlength="{min_length}"') + if max_length is not None: + input_attrs.append(f'maxlength="{max_length}"') + if pattern is not None: + input_attrs.append(f'pattern="{pattern}"') + if default_value is not None: + input_attrs.append(f'value="{default_value}"') + + attrs_str = " ".join(input_attrs) + + if field_type == "checkbox": + checked = "checked" if default_value is True else "" + return f''' +
    + + {f'{help_text}' if help_text else ""} +
    + ''' + + elif field_type == "select": + options_html = [] + if not required: + options_html.append(f'') + + for choice in choices: + selected = "selected" if str(choice) == str(default_value) else "" + options_html.append( + f'' + ) + + return f''' +
    + + + {f'{help_text}' if help_text else ""} +
    + ''' + + elif field_type == "tensor": + return f''' +
    + + + {help_text or "Enter token IDs as comma-separated integers (e.g., 1,2,3,4,5)"} +
    + ''' + + elif field_type == "textarea": + return f''' +
    + + + {f'{help_text}' if help_text else ""} +
    + ''' + + else: + return f''' +
    + + + {f'{help_text}' if help_text else ""} +
    + ''' diff --git a/src/core/http_env_client.py b/src/core/http_env_client.py index 16bbfa5d..007ef6a5 100644 --- a/src/core/http_env_client.py +++ b/src/core/http_env_client.py @@ -1,203 +1,236 @@ -""" -core/runner_env.py -Minimal HTTP-based environment client. -- Talks to a single env worker exposing: POST /reset, POST /step - -Future hooks (commented below) for: -- episode_id, seed on reset -- request_id on step -- custom headers (auth/trace) -""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar - -import requests - -from .client_types import StepResult -from .containers.runtime import LocalDockerProvider - -if TYPE_CHECKING: - from .containers.runtime import ContainerProvider - -ActT = TypeVar("ActT") -ObsT = TypeVar("ObsT") -EnvClientT = TypeVar("EnvClientT", bound="HTTPEnvClient") - - -class HTTPEnvClient(ABC, Generic[ActT, ObsT]): - def __init__( - self, - base_url: str, - request_timeout_s: float = 15.0, - default_headers: Optional[Dict[str, str]] = None, - provider: Optional["ContainerProvider"] = None, - ): - self._base = base_url.rstrip("/") - self._timeout = float(request_timeout_s) - self._http = requests.Session() - self._headers = default_headers or {} - self._provider = provider - - @classmethod - def from_docker_image( - cls: Type[EnvClientT], - image: str, - provider: Optional["ContainerProvider"] = None, - **kwargs: Any, - ) -> EnvClientT: - """ - Create an environment client by spinning up a Docker container locally. - - This is a development utility that: - 1. Starts a Docker container from the specified image - 2. Waits for the server to be ready - 3. Creates and returns a client instance connected to the container - - Note: The container lifecycle management is left to the user or higher-level - orchestration. The container will keep running until manually stopped. 
- - Args: - image: Docker image name to run (e.g., "echo-env:latest") - provider: Container provider to use (defaults to LocalDockerProvider) - **kwargs: Additional arguments to pass to provider.start_container() - (e.g., env_vars, port) - - Returns: - An instance of the client class connected to the running container - - Example: - >>> from envs.coding_env.client import CodingEnv - >>> from envs.coding_env.models import CodeAction - >>> - >>> # Create environment from image - >>> env = CodingEnv.from_docker_image("coding-env:latest") - >>> - >>> # Create environment with custom env vars - >>> env = CodingEnv.from_docker_image( - ... "coding-env:latest", - ... env_vars={"MY_VAR": "value"} - ... ) - >>> - >>> # Use the environment - >>> result = env.reset() - >>> print(result.observation) - >>> - >>> step_result = env.step(CodeAction(code="print('hello')")) - >>> print(step_result.observation.stdout) - >>> - >>> # Cleanup (optional) - >>> env.close() - """ - - # Use default provider if none provided - if provider is None: - provider = LocalDockerProvider() - - # 1. Start container with optional kwargs (e.g., env_vars, port) - base_url = provider.start_container(image, **kwargs) - - # 2. Wait for server to be ready - provider.wait_for_ready(base_url) - - # 3. Create and return client instance with provider reference - return cls(base_url=base_url, provider=provider) - - @classmethod - def from_hub(cls: Type[EnvClientT], repo_id: str, provider: Optional["ContainerProvider"] = None, **kwargs: Any) -> EnvClientT: - """ - Create an environment client by pulling from a Hugging Face model hub. 
- """ - - if provider is None: - provider = LocalDockerProvider() - - if "tag" in kwargs: - tag = kwargs["tag"] - else: - tag = "latest" - - base_url = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}" - - return cls.from_docker_image(image=base_url, provider=provider) - - @abstractmethod - def _step_payload(self, action: ActT) -> dict: - """Convert an Action object to the JSON body expected by the env server.""" - raise NotImplementedError - - @abstractmethod - def _parse_result(self, payload: dict) -> StepResult[ObsT]: - """Convert a JSON response from the env server to StepResult[ObsT].""" - raise NotImplementedError - - @abstractmethod - def _parse_state(self, payload: dict) -> Any: - """Convert a JSON response from the state endpoint to a State object.""" - raise NotImplementedError - - # ---------- Environment Server Interface Methods ---------- - def reset(self) -> StepResult[ObsT]: - body: Dict[str, Any] = {} - # TODO: later: - # body["seed"] = seed - # body["episode_id"] = episode_id - r = self._http.post( - f"{self._base}/reset", - json=body, - headers=self._headers, - timeout=self._timeout, - ) - r.raise_for_status() - return self._parse_result(r.json()) - - def step(self, action: ActT) -> StepResult[ObsT]: - body: Dict[str, Any] = { - "action": self._step_payload(action), - "timeout_s": int(self._timeout), - } - # TODO: later: - # body["request_id"] = str(uuid.uuid4()) - # body["episode_id"] = current_episode_id - r = self._http.post( - f"{self._base}/step", - json=body, - headers=self._headers, - timeout=self._timeout, - ) - r.raise_for_status() - return self._parse_result(r.json()) - - def state(self) -> Any: - """ - Get the current environment state from the server. 
- - Returns: - State object with environment state information (e.g., episode_id, step_count) - - Example: - >>> client = EchoEnv.from_docker_image("echo-env:latest") - >>> result = client.reset() - >>> state = client.state() - >>> print(state.episode_id) - >>> print(state.step_count) - """ - r = self._http.get( - f"{self._base}/state", - headers=self._headers, - timeout=self._timeout, - ) - r.raise_for_status() - return self._parse_state(r.json()) - - def close(self) -> None: - """ - Close the environment and clean up resources. - - If this client was created via from_docker_image(), this will stop - and remove the associated container. - """ - if self._provider is not None: - self._provider.stop_container() +""" +core/runner_env.py +Minimal HTTP-based environment client. +- Talks to a single env worker exposing: POST /reset, POST /step + +Future hooks (commented below) for: +- episode_id, seed on reset +- request_id on step +- custom headers (auth/trace) +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar + +import requests + +from .client_types import StepResult +from .containers.runtime import LocalDockerProvider + +if TYPE_CHECKING: + from .containers.runtime import ContainerProvider + +ActT = TypeVar("ActT") +ObsT = TypeVar("ObsT") +EnvClientT = TypeVar("EnvClientT", bound="HTTPEnvClient") + + +class HTTPEnvClient(ABC, Generic[ActT, ObsT]): + def __init__( + self, + base_url: str, + request_timeout_s: float = 15.0, + default_headers: Optional[Dict[str, str]] = None, + provider: Optional["ContainerProvider"] = None, + ): + self._base = base_url.rstrip("/") + self._timeout = float(request_timeout_s) + self._http = requests.Session() + self._headers = default_headers or {} + self._provider = provider + + @classmethod + def from_docker_image( + cls: Type[EnvClientT], + image: str, + provider: Optional["ContainerProvider"] = None, + **kwargs: Any, + ) -> 
EnvClientT: + """ + Create an environment client by spinning up a Docker container locally. + + This is a development utility that: + 1. Starts a Docker container from the specified image + 2. Waits for the server to be ready + 3. Creates and returns a client instance connected to the container + + Note: The container lifecycle management is left to the user or higher-level + orchestration. The container will keep running until manually stopped. + + Args: + image: Docker image name to run (e.g., "echo-env:latest") + provider: Container provider to use (defaults to LocalDockerProvider) + **kwargs: Additional arguments to pass to provider.start_container() + (e.g., env_vars, port) + + Returns: + An instance of the client class connected to the running container + + Example: + >>> from envs.coding_env.client import CodingEnv + >>> from envs.coding_env.models import CodeAction + >>> + >>> # Create environment from image + >>> env = CodingEnv.from_docker_image("coding-env:latest") + >>> + >>> # Create environment with custom env vars + >>> env = CodingEnv.from_docker_image( + ... "coding-env:latest", + ... env_vars={"MY_VAR": "value"} + ... ) + >>> + >>> # Use the environment + >>> result = env.reset() + >>> print(result.observation) + >>> + >>> step_result = env.step(CodeAction(code="print('hello')")) + >>> print(step_result.observation.stdout) + >>> + >>> # Cleanup (optional) + >>> env.close() + """ + + # Use default provider if none provided + if provider is None: + provider = LocalDockerProvider() + + # 1. Start container with optional kwargs (e.g., env_vars, port) + base_url = provider.start_container(image, **kwargs) + + # 2. Wait for server to be ready + provider.wait_for_ready(base_url) + + # 3. 
Create and return client instance with provider reference + return cls(base_url=base_url, provider=provider) + + @classmethod + def from_hub( + cls: Type[EnvClientT], + repo_id: str, + provider: Optional["ContainerProvider"] = None, + **kwargs: Any, + ) -> EnvClientT: + """ + Create an environment client by pulling from a Hugging Face model hub. + """ + + if provider is None: + provider = LocalDockerProvider() + + if "tag" in kwargs: + tag = kwargs["tag"] + else: + tag = "latest" + + base_url = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}" + + return cls.from_docker_image(image=base_url, provider=provider) + + @abstractmethod + def _step_payload(self, action: ActT) -> dict: + """Convert an Action object to the JSON body expected by the env server.""" + raise NotImplementedError + + @abstractmethod + def _parse_result(self, payload: dict) -> StepResult[ObsT]: + """Convert a JSON response from the env server to StepResult[ObsT].""" + raise NotImplementedError + + @abstractmethod + def _parse_state(self, payload: dict) -> Any: + """Convert a JSON response from the state endpoint to a State object.""" + raise NotImplementedError + + # ---------- Environment Server Interface Methods ---------- + def reset(self, **kwargs: Any) -> StepResult[ObsT]: + """ + Reset the environment with optional parameters. + + Args: + **kwargs: Optional parameters passed to the environment's reset method. 
+ Common parameters include: + - seed: Random seed for reproducibility + - episode_id: Custom episode identifier + - Any environment-specific reset parameters + + Returns: + StepResult containing initial observation + + Example: + >>> env.reset(seed=42, episode_id="ep-001") + """ + body: Dict[str, Any] = kwargs.copy() + r = self._http.post( + f"{self._base}/reset", + json=body, + headers=self._headers, + timeout=self._timeout, + ) + r.raise_for_status() + return self._parse_result(r.json()) + + def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: + """ + Execute an action in the environment with optional parameters. + + Args: + action: The action to execute + **kwargs: Optional parameters passed to the environment's step method. + Common parameters include: + - timeout_s: Execution timeout in seconds + - request_id: Request identifier for tracking + - render: Whether to render the environment + - Any environment-specific step parameters + + Returns: + StepResult containing observation, reward, and done status + + Example: + >>> env.step(action, timeout_s=30.0, request_id="req-123", render=True) + """ + body: Dict[str, Any] = { + "action": self._step_payload(action), + **kwargs # Forward all additional parameters + } + r = self._http.post( + f"{self._base}/step", + json=body, + headers=self._headers, + timeout=self._timeout, + ) + r.raise_for_status() + return self._parse_result(r.json()) + + def state(self) -> Any: + """ + Get the current environment state from the server. 
+ + Returns: + State object with environment state information (e.g., episode_id, step_count) + + Example: + >>> client = EchoEnv.from_docker_image("echo-env:latest") + >>> result = client.reset() + >>> state = client.state() + >>> print(state.episode_id) + >>> print(state.step_count) + """ + r = self._http.get( + f"{self._base}/state", + headers=self._headers, + timeout=self._timeout, + ) + r.raise_for_status() + return self._parse_state(r.json()) + + def close(self) -> None: + """ + Close the environment and clean up resources. + + If this client was created via from_docker_image(), this will stop + and remove the associated container. + """ + if self._provider is not None: + self._provider.stop_container() diff --git a/src/envs/echo_env/models.py b/src/envs/echo_env/models.py index c962629b..88f5da5e 100644 --- a/src/envs/echo_env/models.py +++ b/src/envs/echo_env/models.py @@ -1,36 +1,45 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Data models for the Echo Environment. - -The Echo environment is a simple test environment that echoes back messages. 
-""" - -from dataclasses import dataclass - -# Support both in-repo and standalone imports -try: - # In-repo imports (when running from OpenEnv repository) - from core.env_server.types import Action, Observation -except ImportError: - # Standalone imports (when environment is standalone with openenv-core from pip) - from openenv_core.env_server.types import Action, Observation - - -@dataclass(kw_only=True) -class EchoAction(Action): - """Action for the Echo environment - just a message to echo.""" - - message: str - - -@dataclass(kw_only=True) -class EchoObservation(Observation): - """Observation from the Echo environment - the echoed message.""" - - echoed_message: str - message_length: int = 0 \ No newline at end of file +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Data models for the Echo Environment. + +The Echo environment is a simple test environment that echoes back messages. 
+""" + +from pydantic import Field + +# Support both in-repo and standalone imports +try: + # In-repo imports (when running from OpenEnv repository) + from core.env_server.types import Action, Observation +except ImportError: + # Standalone imports (when environment is standalone with openenv-core from pip) + from openenv_core.env_server.types import Action, Observation + + +class EchoAction(Action): + """Action for the Echo environment - just a message to echo.""" + + message: str = Field( + ..., + min_length=1, + description="Message to echo back" + ) + + +class EchoObservation(Observation): + """Observation from the Echo environment - the echoed message.""" + + echoed_message: str = Field( + ..., + description="The echoed message from the environment" + ) + message_length: int = Field( + default=0, + ge=0, + description="Length of the echoed message" + ) \ No newline at end of file diff --git a/src/envs/echo_env/server/echo_environment.py b/src/envs/echo_env/server/echo_environment.py index 53b383af..b1eb9619 100644 --- a/src/envs/echo_env/server/echo_environment.py +++ b/src/envs/echo_env/server/echo_environment.py @@ -1,102 +1,102 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Echo Environment Implementation. - -A simple test environment that echoes back messages sent to it. -Perfect for testing HTTP server infrastructure. 
class EchoEnvironment(Environment):
    """
    Minimal environment whose observations repeat the incoming message.

    Meant for exercising the HTTP server infrastructure. The only state it
    keeps is an episode id, a step counter, and a private count of resets.

    Example:
        >>> env = EchoEnvironment()
        >>> obs = env.reset()
        >>> print(obs.echoed_message)  # "Echo environment ready!"
        >>>
        >>> obs = env.step(EchoAction(message="Hello"))
        >>> print(obs.echoed_message)  # "Hello"
        >>> print(obs.message_length)  # 5
    """

    def __init__(self):
        """Create the first episode state and zero the reset counter."""
        self._state: State = self._fresh_state()
        self._reset_count: int = 0

    @staticmethod
    def _fresh_state() -> State:
        """Build a State with a newly generated episode id and zero steps."""
        return State(episode_id=str(uuid4()), step_count=0)

    def reset(self) -> EchoObservation:
        """
        Begin a new episode.

        Replaces the current state with a fresh one (new episode id, step
        count 0) and bumps the internal reset counter.

        Returns:
            EchoObservation announcing readiness, with length 0, reward 0.0
            and done=False
        """
        self._state = self._fresh_state()
        self._reset_count = self._reset_count + 1

        greeting = EchoObservation(
            echoed_message="Echo environment ready!",
            message_length=0,
            done=False,
            reward=0.0,
        )
        return greeting

    def step(self, action: EchoAction) -> EchoObservation:  # type: ignore[override]
        """
        Echo the action's message back as an observation.

        Args:
            action: EchoAction whose ``message`` is returned unchanged

        Returns:
            EchoObservation holding the message, its length, a reward of
            0.1 per character, and metadata with the message and step number
        """
        # Advance the counter first so the metadata reports this step's number.
        current = self._state
        current.step_count += 1

        text: str = action.message
        size: int = len(text)

        return EchoObservation(
            echoed_message=text,
            message_length=size,
            done=False,
            # Longer messages earn proportionally larger rewards.
            reward=size * 0.1,
            metadata={"original_message": text, "step": current.step_count},
        )

    @property
    def state(self) -> State:
        """
        Expose the live state object of the current episode.

        Returns:
            The State instance holding episode_id and step_count
        """
        return self._state