In [1]:
import os

### LLMs
# User-agent string exported through the environment.
# NOTE(review): no consumer of USER_AGENT is visible in this notebook — confirm it is still needed.
os.environ["USER_AGENT"] = "Mozilla/5.0"
# Cohere credential for the LLM. setdefault() keeps a real key that is already
# exported in the environment; the placeholder is only written when the
# variable is missing, so running this cell can no longer clobber a valid key.
os.environ.setdefault("COHERE_API_KEY", "<your-api>")  # LLM

In [13]:
from typing import List
from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_cohere import ChatCohere

# Data Model
# Structured answer expected from the LLM; this class is handed to
# PydanticOutputParser, and the chain's result is read through these fields.
# Documented with `#` comments on purpose so the class definition itself
# (fields and descriptions) stays exactly as before.
class ActionModel(BaseModel):
    # Planned moves, as a list of integer action ids.
    action: List[int] = Field(
        description="For given observation select a set of actions from '0' or '1' or '2' or '3'",
    )
    # The model's own verdict on whether the plan reaches the goal.
    check: bool = Field(
        description="check whether the provided actions reaches the Goal or not",
    )
    # Free-text walkthrough of the plan.
    verify: str = Field(
        description="Explain how the provided set of actions reaches goal",
    )
# llm
# Cohere chat model that produces the action plan.
llm = ChatCohere(
    model="command-r-plus",
    temperature=0.3,
)

# parser
# Turns the model's reply into an ActionModel instance.
parser = PydanticOutputParser(pydantic_object=ActionModel)

# prompt
# Few-shot prompt describing the FrozenLake grid, the action encoding and the
# answer fields. `info` is the current observation (cell index) supplied at
# invoke time; `format_instructions` is pre-filled from the parser.
# Changes to the prompt text: removed a leftover meta-sentence, made the task
# statement consistently ask for a list of actions, repaired the garbled
# self-check sentence, used the schema's field name `verify` in the examples,
# and labelled the failing example explicitly as a path to avoid.
prompt = PromptTemplate(
    template="""
    You are playing a game on a grid map where each state has an observation value. The goal is to navigate from the start (S) to the goal (G), avoiding holes (H) and traversing frozen states (F). Below is the map:
        -------------------------
        |S(0) |F(1) |F(2) |F(3) |
        |F(4) |H(5) |F(6) |H(7) |
        |F(8) |F(9) |F(10)|H(11)|
        |H(12)|F(13)|F(14)|G(15)|
        -------------------------
    Each cell has an observation value, and you will be given the current observation value, which indicates your current state. Your task is to select the sequence of actions that reaches the goal. The actions you can take are:

    0: Move left
    1: Move down
    2: Move right
    3: Move up

    Given your current observation value, find a set of actions to reach the goal (G) while avoiding holes (H). Please provide your actions as a list of numbers from (0, 1, 2, or 3).
    After providing the answer, also check whether the provided solution is right or wrong, and then provide an explanation of how the action list reaches the goal.

    for example: (current observation - 9)
    action = [1, 2, 2]
    check = True
    verify = Start from (9) move down(1) reaches (13) then moves right(2) reaches (14) and again moves right(2) reaches (15). Goal is reached.

    another example: (current observation - 2)
    action = [1, 1, 1, 2]
    check = True
    verify = Start from (2) moves down(1) reaches (6), moves down(1) reaches (10), moves down (1) reaches (14), moves right(2) reaches (15). Goal is reached.

    another example, showing a failing path that must be avoided: (current observation - 0)
    action = [1, 2, 2]
    check = False
    verify = Start from (0) moves down(1) reaches (4), then moves right(2) reaches (5). Fallen into hole, Game ends.

    Current Observation:{info}\n
    Format Instruction: {format_instructions}\n""",
    input_variables=["info"],
    partial_variables={"format_instructions": parser.get_format_instructions()},
)

# Pipeline: fill the prompt -> call the LLM -> parse the reply with the parser.
chain = prompt | llm | parser

In [14]:
# Ask the chain for a plan starting from observation 0 and show each parsed field.
response = chain.invoke({"info": 0})
print(f"action:  {response.action}")
print(f"check:  {response.check}")
print(f"exp:  {response.verify}")

action:  [1, 2, 2]
check:  True
exp:  Starting from observation 0 (S), the agent moves down (1) to reach observation 4, then moves right (2) twice to reach observations 6 and 10 respectively. Finally, moving right (2) again, the agent reaches observation 14 (G), successfully avoiding the holes (H).


In [12]:
import gymnasium as gym
# FrozenLake environment (is_slippery=False) built from the same 4x4 layout that the prompt describes.
env = gym.make("FrozenLake-v1", desc = ["SFFF", "FHFH", "FFFH", "HFFG"], map_name="4x4", is_slippery=False, render_mode="human")
# Defined up front so both names exist even when the plan is empty.
reward = None
action = None
try:
    observation, info = env.reset()
    # Replay the LLM's plan one move at a time.
    for action in response.action:
        observation, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            # Report both end-of-episode flags (the old format string dropped `truncated`).
            print("terminated: {}, truncated: {}".format(terminated, truncated))
            # The episode is over: the remaining moves were planned for this
            # episode, so stop instead of replaying them from a fresh start state.
            break
finally:
    # Always release the environment, even if a step raises (e.g. an invalid action id).
    env.close()

terminated: True
