In [1]:
import numpy as np
import torch as t
import gymnasium as gym

from stable_baselines3 import DQN, PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement, CheckpointCallback
from stable_baselines3.common.env_checker import check_env


In [2]:

class SimpleGridEnv(gym.Env):
    """One-dimensional grid world in which the agent walks left/right to a target cell.

    The agent starts each episode on a uniformly random cell in
    ``[0, grid_size)``. Every step costs ``-1``; the step that lands on
    ``target_pos`` yields ``+10`` and terminates the episode.

    Args:
        grid_size: Number of cells on the grid; valid positions are
            ``0 .. grid_size - 1``.
        target_pos: Cell index the agent has to reach.
    """

    def __init__(self, grid_size=10, target_pos=5):
        super().__init__()

        self.grid_size = grid_size
        self.target_pos = target_pos
        self.current_pos = None

        # Action space: 0 for left, 1 for right
        self.action_space = gym.spaces.Discrete(2)

        # Observation space: current position on the grid, as a 1-element float vector
        self.observation_space = gym.spaces.Box(low=0, high=grid_size, shape=(1,), dtype=np.float32)

        # Start from a valid state so step() can be called right after construction
        self.reset()

    def _get_obs(self):
        """Return the current position as a float32 array of shape ``(1,)``."""
        return np.array([self.current_pos], dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Start a new episode at a random grid cell.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` so that the
                environment's own ``np_random`` generator is (re)seeded and
                the start position becomes reproducible.
            options: Accepted for Gymnasium API compatibility; unused.

        Returns:
            Tuple ``(observation, info)`` where ``info`` is an empty dict.
        """
        # Seeds self.np_random when a seed is given (Gymnasium convention).
        super().reset(seed=seed)

        # Draw from the env-owned generator rather than the global NumPy state,
        # so that reset(seed=...) is deterministic.
        self.current_pos = int(self.np_random.integers(0, self.grid_size))
        return self._get_obs(), {}

    def step(self, action):
        """Move one cell left (``0``) or right (``1``), clamped to the grid.

        Args:
            action: ``0`` to move left, ``1`` to move right. Any other value
                leaves the position unchanged.

        Returns:
            Tuple ``(observation, reward, terminated, truncated, info)``.
            ``terminated`` is ``True`` once the target cell is reached;
            ``truncated`` is always ``False`` because this environment
            imposes no time limit of its own.
        """
        if action == 0:
            self.current_pos = max(0, self.current_pos - 1)
        elif action == 1:
            self.current_pos = min(self.grid_size - 1, self.current_pos + 1)

        # Reaching the target is a genuine terminal state, not a time-limit cut-off.
        terminated = bool(self.current_pos == self.target_pos)

        # Reward: -1 for each step, +10 if the target position is reached
        reward = 10 if terminated else -1

        truncated = False
        info = {}

        return self._get_obs(), reward, terminated, truncated, info

    def render(self, mode='human'):
        """Print the agent's current position to stdout."""
        print(f"Current Position: {self.current_pos}")

In [3]:
# Build a 10-cell grid world whose goal is cell 5, then let Stable-Baselines3
# validate the environment against the Gymnasium API.
env = SimpleGridEnv(
    target_pos=5,
    grid_size=10,
)
check_env(env)