In [2]:
from abc import ABC, abstractmethod
from typing import Any

import gymnasium as gym
from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam


class Environment(gym.Env, ABC):
    """The base environment class for agents to interact with in the CRAB framework.

    A CRAB environment is a subclass of `gymnasium.Env` and is designed to be the base
    class for all environments in CRAB. To make an environment usable through the
    OpenAI tool-use API, subclasses must implement three methods:

    * `get_description` - a text description of the environment for the agent prompt;
    * `get_action_schema` - the tool schemas describing the whole action space;
    * `convert_tool_call_to_action` - translates a tool call back into an action of
      the environment's own action space.
    """

    @abstractmethod
    def get_description(self) -> str:
        """Get the description of the environment, which can be used as a part of the
        agent prompt.

        Returns:
            A string description of the environment.
        """
        ...

    @abstractmethod
    def get_action_schema(self) -> list[ChatCompletionToolParam]:
        """Get the tool schema for the action space of the environment.

        The schema describes the whole action space, including the parameters of every
        possible action, in the tool-calling format used by the OpenAI API. It should
        be comprehensive and must not be open to misunderstanding by a human reader.

        Returns:
            A list of tool schemas.
        """
        ...

    @abstractmethod
    def convert_tool_call_to_action(self, tool_name: str, parameters: dict) -> Any:
        """Convert a tool call to the actual action space in the environment.

        Args:
            tool_name: The name of the tool.
            parameters: The parameters of the tool call.

        Returns:
            The corresponding action in the environment's own action space.
        """
        ...

In [3]:
import gymnasium as gym
from gymnasium.envs.classic_control.acrobot import AcrobotEnv

from crab.core.decorators import action
from refactor_demo.envs.multi_env import MultiEnv


# Agent-facing tool; `CrabAcrobotEnv.MAP` maps the name "left" to discrete action 0.
# NOTE(review): the docstring is presumably read by `@action` as the tool
# description sent to the model — treat it as prompt text, not just documentation.
@action
def left():
    """apply -1 torque to the actuated joint"""


# Agent-facing tool; `CrabAcrobotEnv.MAP` maps the name "right" to discrete action 2.
# NOTE(review): the docstring is presumably read by `@action` as the tool
# description sent to the model — treat it as prompt text, not just documentation.
@action
def right():
    """apply +1 torque to the actuated joint"""


# Agent-facing tool; `CrabAcrobotEnv.MAP` maps the name "no_torque" to discrete action 1.
# NOTE(review): the docstring is presumably read by `@action` as the tool
# description sent to the model — treat it as prompt text, not just documentation.
@action
def no_torque():
    """apply 0 torque to the actuated joint"""


class CrabAcrobotEnv(AcrobotEnv, Environment):
    """Gymnasium's Acrobot environment exposed through the CRAB `Environment` API.

    The discrete actions are offered to the agent as the tools `left`, `no_torque`
    and `right`. Tool calls are translated back to the integer actions that
    `AcrobotEnv.step` consumes.
    """

    # Tool name -> discrete Acrobot action index.
    # NOTE(review): assumes Acrobot applies torque [-1, 0, +1][index], which matches
    # the docstrings of the `left` / `no_torque` / `right` tools — confirm against
    # the installed gymnasium version.
    MAP = {"left": 0, "no_torque": 1, "right": 2}

    def get_description(self) -> str:
        """Get the description of the environment, which can be used as a part of the
        agent prompt.

        Returns:
            A string description of the environment.
        """
        # NOTE(review): the observation-space section inside this string is indented
        # by four spaces, so a Markdown renderer would treat it as a code block. The
        # text is sent to the model verbatim.
        return """The system consists of two links connected linearly to form a chain, with one end of \
the chain fixed. The joint between the two links is actuated. The goal is to apply \
torques on the actuated joint to swing the free end of the linear chain above a \
given height while starting from the initial state of hanging downwards.

    ## Observation Space

    The observation is a `ndarray` with shape `(6,)` that provides information about the
    two rotational joint angles as well as their angular velocities:

    | Num | Observation                  | Min                 | Max               |
    |-----|------------------------------|---------------------|-------------------|
    | 0   | Cosine of `theta1`           | -1                  | 1                 |
    | 1   | Sine of `theta1`             | -1                  | 1                 |
    | 2   | Cosine of `theta2`           | -1                  | 1                 |
    | 3   | Sine of `theta2`             | -1                  | 1                 |
    | 4   | Angular velocity of `theta1` | ~ -12.567 (-4 * pi) | ~ 12.567 (4 * pi) |
    | 5   | Angular velocity of `theta2` | ~ -28.274 (-9 * pi) | ~ 28.274 (9 * pi) |
"""

    def get_action_schema(self) -> list[ChatCompletionToolParam]:
        """Get the tool schema for the action space of the environment.

        Returns:
            One schema per tool, in the order `left`, `right`, `no_torque`.
        """
        return [tool.to_openai_json_schema() for tool in (left, right, no_torque)]

    def convert_tool_call_to_action(self, tool_name: str, parameters: dict) -> Any:
        """Convert a tool call to the actual action space in the environment.

        Args:
            tool_name: The name of the tool. Must be a key of `MAP`.
            parameters: The parameters of the tool call. These are unused because
                none of the Acrobot tools take arguments.

        Returns:
            The discrete action index for `AcrobotEnv.step`.

        Raises:
            KeyError: If `tool_name` is not one of the known tools. The message lists
                the valid names, which helps when a model invents a tool name.
        """
        try:
            return self.MAP[tool_name]
        except KeyError:
            raise KeyError(
                f"Unknown tool {tool_name!r}; expected one of {sorted(self.MAP)}"
            ) from None


# Shared Acrobot instance used by the rest of the notebook.
env = CrabAcrobotEnv()

In [4]:
from dataclasses import dataclass


from collections.abc import Callable


@dataclass
class Task:
    """A goal for an agent to accomplish in an environment.

    Attributes:
        description: Natural-language statement of the goal. `TaskWrapper` attaches
            it to every observation.
        evaluate: Called with the environment. `TaskWrapper.step` uses its return
            value as the step reward.
    """

    description: str
    # `callable` is the builtin predicate function, not a type; use a real
    # callable type for the annotation.
    evaluate: Callable[[Any], float]


# TODO: placeholder evaluator. It always returns True, so every step gets the same
# reward. Replace it with a real height check on the free end of the chain.
task = Task(
    description="apply torques on the actuated joint to swing the free end of the linear chain above a given height while starting from the initial state of hanging downwards.",
    evaluate=lambda env: True,
)

In [19]:
from typing import Generic

from gymnasium import Wrapper
from gymnasium.core import ActType, ObsType, WrapperObsType
from gymnasium.spaces import Dict, Space, Text, Tuple


import string


class TaskWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]):
    """Attach a task's description to every observation and score steps with the task.

    The wrapped observation space depends on the inner environment's space:

    * ``Dict``: the same ``Dict`` plus a ``dict_task_key`` entry holding the text.
    * ``Tuple``: the same ``Tuple`` with the text appended as the last element.
    * Anything else: a ``Dict`` with the original observation under ``"obs"`` and
      the text under ``dict_task_key``.

    The reward returned by :meth:`step` is ``task.evaluate(env)``. The inner
    environment's own reward is discarded.
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        task: Task,
        *,
        dict_task_key: str = "task",
    ):
        """
        Args:
            env: The environment to wrap.
            task: The task whose description is attached to observations and whose
                ``evaluate`` provides the reward.
            dict_task_key: Key under which the task text is stored when the wrapped
                observation is a dict.

        Raises:
            AssertionError: If ``dict_task_key`` clashes with an existing key.
        """
        super().__init__(env)
        self.env = env
        self.task = task

        # By default `Text` only admits alphanumerics, so the spaces and punctuation in
        # a description would fall outside the space. Allow every printable character
        # plus any other character the description uses (sorted, so the charset order
        # is deterministic), and make room for descriptions longer than 500 characters.
        extra_chars = sorted(set(task.description) - set(string.printable))
        task_space = Text(
            max_length=max(500, len(task.description)),
            charset=string.printable + "".join(extra_chars),
        )

        # Observation space in different situations
        if isinstance(env.observation_space, Dict):
            assert dict_task_key not in env.observation_space.keys()
            observation_space = Dict(
                {dict_task_key: task_space, **env.observation_space.spaces}
            )
            self._append_data_func = lambda obs, task: {dict_task_key: task, **obs}
        elif isinstance(env.observation_space, Tuple):
            observation_space = Tuple(env.observation_space.spaces + (task_space,))
            self._append_data_func = lambda obs, task: obs + (task,)
        else:
            # Honour `dict_task_key` here too. The default ("task") keeps the
            # original layout.
            assert dict_task_key != "obs"
            observation_space = Dict(
                {"obs": env.observation_space, dict_task_key: task_space}
            )
            self._append_data_func = lambda obs, task: {"obs": obs, dict_task_key: task}

        self.observation_space: gym.Space[WrapperObsType] = observation_space

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Reset the wrapped env and return its first observation with the task attached.

        Args:
            seed: Forwarded to the wrapped env's ``reset``.
            options: Forwarded to the wrapped env's ``reset``.

        Returns:
            ``(observation, info)``. The observation has passed through
            :meth:`observation`.
        """
        obs, info = self.env.reset(seed=seed, options=options)
        return self.observation(obs), info

    def step(
        self, action: ActType
    ) -> tuple[WrapperObsType, float, bool, bool, dict[str, Any]]:
        """Step the wrapped env, attach the task text, and score with the task.

        Args:
            action: Action forwarded to the wrapped env.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. ``reward`` is
            ``task.evaluate(env)``, not the wrapped env's reward.
        """
        # Delegate to the wrapped env. The original called `self.step`, which
        # recursed forever.
        observation, _env_reward, terminated, truncated, info = self.env.step(action)
        reward = self.task.evaluate(self.env)
        return self.observation(observation), reward, terminated, truncated, info

    def observation(self, observation: ObsType):
        """Returns a modified observation.

        Args:
            observation: The :attr:`env` observation

        Returns:
            The observation with ``task.description`` attached, in the layout chosen
            in ``__init__``.
        """
        return self._append_data_func(observation, self.task.description)

In [20]:
# Wrap the Acrobot env so each observation also carries the task description,
# then take the first (observation, info) pair.
task_env = TaskWrapper(env, task=task)
(o, i) = task_env.reset()

In [21]:
# Show the wrapped observation. The Acrobot state is under "obs" and the task text
# under "task" (the non-Dict branch of `TaskWrapper`).
o

{'obs': array([ 9.9810892e-01,  6.1470319e-02,  1.0000000e+00, -2.1458303e-05,
        -9.0955026e-02, -7.1539722e-02], dtype=float32),
 'task': 'apply torques on the actuated joint to swing the free end of the linear chain above a given height while starting from the initial state of hanging downwards.'}

In [4]:
import openai

# Ask the model for the next action, given the current observation from the raw
# (unwrapped) env. `openai.Client()` reads its credentials from the environment
# (presumably OPENAI_API_KEY), so no key is hard-coded here.
client = openai.Client()
o, _ = env.reset()


result = client.chat.completions.create(
    model="gpt-4-0613",
    messages=[
        {"role": "system", "content": env.get_description()},
        # Separate the observation from the instruction. Plain concatenation glued
        # the array's closing bracket directly onto "Tell me next step".
        {"role": "user", "content": str(o) + "\nTell me next step"},
    ],
    # NOTE(review): each schema entry is wrapped as a "function" tool here, even
    # though `get_action_schema` is annotated as returning `ChatCompletionToolParam`s
    # (which already carry "type"/"function"). Confirm the shape that
    # `to_openai_json_schema` returns, so tools are not double-wrapped.
    tools=[{"function": tool, "type": "function"} for tool in env.get_action_schema()],
    # Force a tool call, so the reply is always one of the action tools.
    tool_choice="required",
)
print(result.choices[0].message.tool_calls)

: 

In [14]:
# Show the most recent raw observation. In top-to-bottom order, `o` was last
# rebound by `o, _ = env.reset()` in the OpenAI cell, so this is the bare Acrobot
# array rather than the wrapped dict.
o

array([ 0.999577  ,  0.02908438,  0.9999982 , -0.00189753,  0.08006953,
        0.06967726], dtype=float32)