<a href="https://colab.research.google.com/github/iso-ai/isopro_examples/blob/main/examples/run_cartpole_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LLM-based CartPole Reinforcement Learning Agent

This notebook demonstrates how to create and train a Reinforcement Learning agent that uses a Large Language Model (LLM) to make decisions in the CartPole environment.

## Setup

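First, let's install the required packages, import the necessary libraries, and set up logging and a results folder. The notebook is written for Google Colab: the Anthropic API key is read from Colab's secrets with `userdata.get("ANTHROPIC_API_KEY")`, so add a secret with that name before running it.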

In [None]:
!pip install isopro

In [None]:
!pip install iso-adverse

In [None]:
import gymnasium as gym
from isopro.rl.rl_agent import RLAgent
from isopro.rl.rl_environment import LLMRLEnvironment
from stable_baselines3 import PPO
import numpy as np
import anthropic
import logging
from typing import Optional, Dict, Any
from tqdm import tqdm
import json
from datetime import datetime
from google.colab import userdata
import os

# Create a folder to store the results (no error if it already exists)
output_folder = "results"
os.makedirs(output_folder, exist_ok=True)

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

## Define the LLMCartPoleWrapper

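Now, let's define our `LLMCartPoleWrapper` class, which integrates the LLM with the CartPole environment. It wraps Gymnasium's `CartPole-v1` and caps the number of LLM calls per episode: once the cap is reached, `step()` ignores the agent's action and pushes the cart left. It also provides two helpers: `_update_llm`, which sends the current observation to Claude and records the exchange in `conversation_history`, and `_llm_decision_to_cartpole_action`, which maps a text reply to a CartPole action. Note that `step()` only counts calls against the budget; it does not invoke `_update_llm` itself.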
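The call-limit bookkeeping that the wrapper performs in `reset()` and `step()` can be isolated into a small counter, which makes it easier to see the difference between a per-episode budget (what zeroing the count in `reset()` gives you) and a budget for the whole run. `CallBudget` and `calls_allowed` are hypothetical names used only for this sketch; they are not part of isopro.

```python
class CallBudget:
    """Hypothetical counter for rationing LLM calls.

    per_episode=True mirrors LLMCartPoleWrapper, whose reset() zeroes the count;
    per_episode=False caps calls across the whole run instead.
    """

    def __init__(self, limit: int, per_episode: bool = True) -> None:
        self.limit = limit
        self.per_episode = per_episode
        self.used = 0

    def on_reset(self) -> None:
        # Only a per-episode budget is refilled at the start of each episode
        if self.per_episode:
            self.used = 0

    def try_consume(self) -> bool:
        """Return True and count one call if budget remains, else return False."""
        if self.used >= self.limit:
            return False
        self.used += 1
        return True


def calls_allowed(budget: CallBudget, episode_lengths) -> int:
    """Count how many steps would get an LLM call over a sequence of episodes."""
    total = 0
    for length in episode_lengths:
        budget.on_reset()
        total += sum(budget.try_consume() for _ in range(length))
    return total


episodes = [30, 150, 80]  # illustrative episode lengths
print(calls_allowed(CallBudget(100, per_episode=True), episodes))   # 210
print(calls_allowed(CallBudget(100, per_episode=False), episodes))  # 100
```

With a per-episode budget, the total number of calls grows with the number of episodes, so only a run-wide budget puts a hard ceiling on API usage.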

In [None]:
class LLMCartPoleWrapper(LLMRLEnvironment):
    def __init__(self, agent_prompt, llm_call_limit: int):
        super().__init__(agent_prompt, None)
        self.cartpole_env = gym.make('CartPole-v1')
        self.action_space = self.cartpole_env.action_space
        self.observation_space = self.cartpole_env.observation_space
        self.client = anthropic.Anthropic(api_key=userdata.get("ANTHROPIC_API_KEY"))
        self.llm_call_count = 0
        self.llm_call_limit = llm_call_limit  # Set the maximum number of LLM calls allowed

    def reset(self, **kwargs):
        # Reset the environment and the LLM call count
        self.llm_call_count = 0
        return self.cartpole_env.reset(**kwargs)

    def step(self, action):
        if self.llm_call_count >= self.llm_call_limit:
            # Budget for this episode is used up: ignore the agent's action
            # and take a fixed default action instead (0 = push cart left)
            logger.warning("LLM call limit reached, default action taken")
            return self.cartpole_env.step(0)  # Default action can be customized

        # Count this step against the per-episode budget, then apply the agent's action.
        # step() does not query the LLM itself; _update_llm() is what makes the API call.
        self.llm_call_count += 1
        return self.cartpole_env.step(action)


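    def _llm_decision_to_cartpole_action(self, llm_decision):
        # CartPole actions: 0 = push cart left, 1 = push cart right
        if isinstance(llm_decision, (int, np.integer)):
            return int(llm_decision)
        elif isinstance(llm_decision, np.ndarray) and llm_decision.size == 1:
            # Actions returned by model.predict() arrive as NumPy arrays
            return int(llm_decision.item())
        elif isinstance(llm_decision, str):
            # Any reply mentioning "left" maps to 0; everything else maps to 1
            return 0 if "left" in llm_decision.lower() else 1
        else:
            raise ValueError(f"Unexpected action type: {type(llm_decision)}")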

    def _update_llm(self, observation, reward, done):
        user_message = f"Observation: {observation}, Reward: {reward}, Done: {done}. What action should we take next?"

        messages = self.conversation_history + [
            {"role": "user", "content": user_message},
        ]

        response = self.client.messages.create(
            model="claude-3-opus-20240229",
            max_tokens=150,
            system=self.agent_prompt,
            messages=messages
        )

        ai_response = response.content[0].text
        self.conversation_history.append({"role": "user", "content": user_message})
        self.conversation_history.append({"role": "assistant", "content": ai_response})
        logger.debug(f"LLM updated. AI response: {ai_response}")
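`_llm_decision_to_cartpole_action` treats any reply containing the word "left" as action 0, so a reply such as "Move right, not left." would be mapped to a left push, and any reply without "left" becomes a right push. A stricter parser that only accepts the two phrases the agent prompt requests ('Move left' / 'Move right') could look like the sketch below. `parse_llm_action` is a hypothetical helper, not part of isopro; it returns `None` for unrecognized replies so the caller decides the fallback.

```python
import re
from typing import Optional

# Accept only the two phrases the agent prompt asks for, case-insensitively
_MOVE_RE = re.compile(r"\bmove\s+(left|right)\b", re.IGNORECASE)


def parse_llm_action(reply: str) -> Optional[int]:
    """Hypothetical helper: map an LLM reply to a CartPole action.

    Returns 0 (push left) or 1 (push right) for the first 'Move left' /
    'Move right' phrase in the reply, or None if neither phrase appears.
    """
    match = _MOVE_RE.search(reply)
    if match is None:
        return None
    return 0 if match.group(1).lower() == "left" else 1


print(parse_llm_action("Move right, not left."))  # 1
print(parse_llm_action("I would MOVE LEFT here"))  # 0
print(parse_llm_action("Hard to say."))            # None
```

When the helper returns `None`, a caller could fall back to the policy's own action (or to the wrapper's default of 0) rather than silently treating every unrecognized reply as "right".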

## Create and Train the RL Agent

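Now, let's create a PPO agent from Stable-Baselines3 and train it on the wrapped CartPole environment.

To limit LLM usage, `llm_call_limit` is set to 100 and `total_timesteps` to 20. Two details affect what these settings do in practice. First, the wrapper resets its call counter in `reset()`, so the limit of 100 applies to each episode rather than to the whole run. Second, PPO collects experience in rollouts of `n_steps` environment steps (2048 by default) before it checks `total_timesteps`, so this call still runs one full 2048-step rollout followed by a single policy update. That is enough to exercise the pipeline, but far short of the training PPO typically needs to solve CartPole.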
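To estimate how many environment steps (and therefore how many wrapper `step()` calls) a training run will take, the sketch below assumes Stable-Baselines3's on-policy loop: `learn()` keeps collecting complete rollouts of `n_steps` steps per environment while the step count is below `total_timesteps`. `ppo_steps_collected` is a hypothetical helper, and it assumes a freshly created model with no callback stopping training early.

```python
import math


def ppo_steps_collected(total_timesteps: int, n_steps: int = 2048, n_envs: int = 1) -> int:
    """Environment steps a fresh SB3 PPO model takes in learn(total_timesteps).

    Each rollout collects n_steps transitions in each of n_envs environments,
    and learn() keeps collecting whole rollouts while num_timesteps < total_timesteps.
    """
    steps_per_rollout = n_steps * n_envs
    rollouts = math.ceil(total_timesteps / steps_per_rollout)
    return rollouts * steps_per_rollout


print(ppo_steps_collected(20))              # 2048: one full default rollout
print(ppo_steps_collected(20, n_steps=20))  # 20
print(ppo_steps_collected(5000))            # 6144: three rollouts
```

To make the training cell take only about 20 steps, `n_steps` can be passed to the constructor, for example `PPO("MlpPolicy", env, n_steps=20, batch_size=20, verbose=1)`; lowering `batch_size` as well avoids Stable-Baselines3's warning that the rollout buffer size is not a multiple of the default mini-batch size of 64.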

In [None]:
agent_prompt = """You are an AI trained to play the CartPole game.
Your goal is to balance a pole on a moving cart for as long as possible.
You will receive observations about the cart's position, velocity, pole angle, and angular velocity.
Based on these, you should decide whether to move the cart left or right.
Respond with 'Move left' or 'Move right' for each decision."""

env = LLMCartPoleWrapper(agent_prompt, llm_call_limit=100)
model = PPO("MlpPolicy", env, verbose=1)

logger.info("Starting training")
model.learn(total_timesteps=20)
logger.info("Training completed")
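The agent prompt above tells the model it will receive the cart's position, velocity, pole angle, and angular velocity, but `_update_llm` interpolates the raw observation array, so the model sees four unlabeled numbers. A labeled message builder could look like the sketch below. `format_cartpole_observation` is a hypothetical helper; the field order assumes Gymnasium's `CartPole-v1` observation layout (cart position, cart velocity, pole angle in radians, pole angular velocity), and the closing question is the one `_update_llm` already asks.

```python
import math
from typing import Sequence

# Gymnasium CartPole-v1 observation layout
CARTPOLE_FIELDS = ("cart position", "cart velocity", "pole angle (rad)", "pole angular velocity")


def format_cartpole_observation(observation: Sequence[float], reward: float, done: bool) -> str:
    """Hypothetical helper: build a labeled user message from one CartPole step."""
    if len(observation) != len(CARTPOLE_FIELDS):
        raise ValueError(f"Expected {len(CARTPOLE_FIELDS)} values, got {len(observation)}")
    parts = [f"{name}: {float(value):+.3f}" for name, value in zip(CARTPOLE_FIELDS, observation)]
    # Also report the angle in degrees; CartPole-v1 terminates beyond +/-12 degrees
    parts.append(f"pole angle (deg): {math.degrees(float(observation[2])):+.1f}")
    return (
        "Observation: " + ", ".join(parts)
        + f". Reward: {reward}, Done: {done}. What action should we take next?"
    )


print(format_cartpole_observation([0.0, 0.1, 0.05, -0.2], 1.0, False))
```

Because it only relies on `len`, indexing, and `float`, the helper accepts either a Python list or the NumPy array returned by `env.reset()` and `env.step()`.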

## Test the Trained Agent

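Now that we've trained our agent, let's test it for 2 episodes and see how it performs. Each episode runs until it terminates (the pole tilts more than 12 degrees or the cart moves more than 2.4 units from the center) or is truncated at the 500-step limit of `CartPole-v1`. Because the test loop reuses the wrapped environment, the per-episode limit of 100 LLM calls still applies: after 100 steps the wrapper always pushes the cart left, which will likely end the episode shortly afterwards.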

In [None]:
test_episodes = 2
results = []

logger.info("Starting test episodes")
for episode in tqdm(range(test_episodes), desc="Test Episodes"):
    obs, _ = env.reset()
    done = False
    total_reward = 0
    episode_length = 0
    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        episode_length += 1
        done = terminated or truncated

    logger.info(f"Episode {episode + 1} completed. Total reward: {total_reward}, Length: {episode_length}")
    results.append({"episode": episode + 1, "total_reward": total_reward, "length": episode_length})

# Save results to file
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_file = os.path.join(output_folder, f"cartpole_results_{timestamp}.json")
with open(output_file, 'w') as f:
    json.dump(results, f, indent=2)
logger.info(f"Results saved to {output_file}")

# Print summary
average_reward = sum(r['total_reward'] for r in results) / len(results)
average_length = sum(r['length'] for r in results) / len(results)
logger.info(f"Test completed. Average reward: {average_reward:.2f}, Average length: {average_length:.2f}")
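In `CartPole-v1` the environment pays a reward of +1 for every step, including the final one, and truncates episodes at 500 steps, so each `total_reward` recorded above should equal its `length`, and 500 is the best possible score. The sketch below checks that relationship and reports how many episodes reached the cap. `summarize_cartpole_results` is a hypothetical helper, and the example records are made-up values, not output from this notebook.

```python
from statistics import mean
from typing import Dict, List

CARTPOLE_V1_MAX_STEPS = 500  # episode step limit applied by gym.make("CartPole-v1")


def summarize_cartpole_results(results: List[Dict]) -> Dict[str, float]:
    """Hypothetical helper for records shaped like the `results` list above."""
    if not results:
        raise ValueError("No episodes to summarize")
    for r in results:
        # CartPole-v1 pays +1 per step, so total reward should equal episode length
        if r["total_reward"] != r["length"]:
            raise ValueError(
                f"Episode {r['episode']}: reward {r['total_reward']} != length {r['length']}"
            )
    lengths = [r["length"] for r in results]
    return {
        "average_length": mean(lengths),
        "best_length": max(lengths),
        "fraction_at_max": sum(n >= CARTPOLE_V1_MAX_STEPS for n in lengths) / len(lengths),
    }


# Made-up records for illustration, not results from this notebook
example = [
    {"episode": 1, "total_reward": 500.0, "length": 500},
    {"episode": 2, "total_reward": 112.0, "length": 112},
]
print(summarize_cartpole_results(example))
```

Called on the notebook's own `results` list, the check would flag any change that breaks the reward/length relationship, for instance reward shaping added to the wrapper later.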

## Conclusion

In this notebook, we've demonstrated how to:

1. Set up an LLM-based wrapper for the CartPole environment
2. Train a reinforcement learning agent using this environment
3. Test the trained agent and collect performance metrics

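This approach combines the decision-making capabilities of a large language model with the learning process of reinforcement learning. The wrapper supplies the pieces for that combination (a per-episode budget on LLM calls, a helper that sends observations to Claude, and a mapping from text replies to CartPole actions), which could lead to interesting and novel solutions to the CartPole problem.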

Feel free to experiment with different prompts, training parameters, or even different environments to see how this approach can be applied in various scenarios!