In [1]:
import pygame
import numpy as np
from time import sleep
from gymnasium import Env, spaces, register, make
import random

pygame 2.5.2 (SDL 2.28.3, Python 3.10.12)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [None]:

class BanditSlipperyWalkEnv(Env):
    """Three-cell "Bandit Slippery Walk" environment (a one-step, two-armed bandit).

    The agent starts in the middle cell (1). Cell 0 is a terminal "dead" state
    (reward 0) and cell 2 is the terminal goal (reward 1). Action 0 moves left
    and action 1 moves right, but with probability ``slip_prob`` the agent moves
    in the opposite direction. Every transition out of cell 1 ends the episode.

    Args:
        render_mode: ``None``, ``"human"`` (pygame window) or ``"rgb_array"``.
        slip_prob: Probability of moving opposite to the chosen action from the
            start cell. Must lie in ``[0, 1]``; defaults to 0.2.

    Raises:
        AssertionError: if ``render_mode`` is not one of the supported modes.
        ValueError: if ``slip_prob`` is outside ``[0, 1]``.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

    def __init__(self, render_mode=None, slip_prob=0.2):
        assert render_mode is None or render_mode in self.metadata["render_modes"]
        if not 0.0 <= slip_prob <= 1.0:
            raise ValueError(f"slip_prob must be in [0, 1], got {slip_prob}")

        # The probability of the slip
        self.slip_prob = slip_prob
        move_prob = 1.0 - slip_prob

        # Transition model: P[state][action] -> list of
        # (probability, next_state, reward, terminated) tuples.
        # Built from `slip_prob` so the constructor argument actually takes
        # effect; the default 0.2 yields the 0.8 / 0.2 split.
        self.P = {
            0: {
                0: [(1.0, 0, 0.0, True)],
                1: [(1.0, 0, 0.0, True)],
            },
            1: {
                0: [(move_prob, 0, 0.0, True), (slip_prob, 2, 1.0, True)],
                1: [(move_prob, 2, 1.0, True), (slip_prob, 0, 0.0, True)],
            },
            2: {
                0: [(1.0, 2, 0.0, True)],
                1: [(1.0, 2, 0.0, True)],
            },
        }
        self.size = 3  # The size of the 1D grid
        self.window_size = 512  # The width of the PyGame window in pixels

        # We have 3 observations, corresponding to each position in the 1-D grid
        self.observation_space = spaces.Discrete(self.size)

        # We have 2 actions, corresponding to "left" (0) & "right" (1)
        self.action_space = spaces.Discrete(2)

        self.render_mode = render_mode

        # If human-rendering is used, `self.window` is the window we draw to and
        # `self.clock` keeps rendering at `metadata["render_fps"]`. Both stay
        # `None` until human-mode rendering happens for the first time.
        self.window = None
        self.clock = None

    def _get_obs(self):
        """Return the current observation as ``{"agent": int, "target": int}``.

        NOTE(review): this dict is not an element of ``self.observation_space``
        (declared as ``Discrete(self.size)``), so gymnasium's environment checker
        will warn about it. Returning ``self._agent_location`` alone would match
        the declared space. It is kept as-is here to avoid changing what
        ``reset``/``step`` hand back to existing callers.
        """
        return {"agent": self._agent_location, "target": self._target_location}

    def _get_info(self):
        """Return auxiliary info: the absolute 1-D distance from agent to target."""
        return {
            "distance": abs(self._agent_location - self._target_location)
        }

    def reset(self, seed=None, options=None):
        """Start a new episode with the agent in the middle cell.

        Args:
            seed: Optional seed forwarded to ``Env.reset``, which (re)seeds
                ``self.np_random`` (the generator ``step`` samples from).
            options: Unused; accepted for API compatibility.

        Returns:
            ``(observation, info)`` as built by ``_get_obs`` / ``_get_info``.
        """
        super().reset(seed=seed)
        self._agent_location = 1
        self._target_location = self.size - 1  # rightmost cell, reward 1
        self._dead_state = 0  # leftmost cell, reward 0

        observation = self._get_obs()
        info = self._get_info()

        if self.render_mode == "human":
            self._render_frame()

        return observation, info

    def step(self, action):
        """Apply ``action`` and sample the outcome from ``self.P``.

        Args:
            action: 0 (left) or 1 (right).

        Returns:
            The gymnasium 5-tuple ``(observation, reward, terminated, truncated,
            info)``. ``truncated`` is always ``False``. ``info["log"]`` records
            ``current_state``, ``action`` and ``next_state`` for this step.
        """
        prev_location = self._agent_location
        transitions = self.P[prev_location][action]
        probabilities, next_states, rewards, terminals = zip(*transitions)

        # Sample with the env's own generator so `reset(seed=...)` makes runs
        # reproducible; the global `random` module ignores that seed.
        index = self.np_random.choice(len(transitions), p=probabilities)
        self._agent_location = next_states[index]
        reward = rewards[index]
        terminated = terminals[index]

        truncated = False
        observation = self._get_obs()
        info = self._get_info()

        info["log"] = {"current_state": prev_location,
                       "action": action,
                       "next_state": self._agent_location}

        if self.render_mode == "human":
            self._render_frame()

        # Return the required 5-tuple
        return observation, reward, terminated, truncated, info

    def render(self):
        """Return an RGB frame in ``"rgb_array"`` mode; otherwise ``None``.

        In ``"human"`` mode, frames are drawn during ``reset``/``step`` instead.
        """
        if self.render_mode == "rgb_array":
            return self._render_frame()

    def _render_frame(self):
        """Draw the 1-D grid: goal (green), dead state (red), agent (blue circle).

        Returns:
            In ``"rgb_array"`` mode, a ``(height, width, 3)`` pixel array.
            In ``"human"`` mode, ``None`` (the frame is shown in the window).
        """
        # The size of a single grid square in pixels
        pix_square_size = self.window_size / self.size
        # The strip is one cell tall. Pygame sizes are integers, so pass them
        # explicitly instead of relying on float coercion.
        canvas_size = (self.window_size, int(pix_square_size))

        if self.window is None and self.render_mode == "human":
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode(canvas_size)

        if self.clock is None and self.render_mode == "human":
            self.clock = pygame.time.Clock()

        canvas = pygame.Surface(canvas_size)
        canvas.fill((255, 255, 255))

        def cell_rect(cell):
            # Square covering grid cell `cell` in the single-row strip
            return pygame.Rect(
                (pix_square_size * cell, 0),
                (pix_square_size, pix_square_size),
            )

        # First we draw the target
        pygame.draw.rect(canvas, (0, 255, 0), cell_rect(self._target_location))

        # Then we draw the dead state
        pygame.draw.rect(canvas, (255, 0, 0), cell_rect(self._dead_state))

        # Now we draw the agent, centred in its cell
        pygame.draw.circle(
            canvas,
            (0, 0, 255),
            ((self._agent_location + 0.5) * pix_square_size, 0.5 * pix_square_size),
            pix_square_size / 3,
        )

        # Finally, add vertical gridlines spanning the canvas height
        for x in range(self.size + 1):
            pygame.draw.line(
                canvas,
                0,
                (pix_square_size * x, 0),
                (pix_square_size * x, canvas_size[1]),
                width=3,
            )

        if self.render_mode == "human":
            # Copy our drawings from `canvas` to the visible window
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()

            # Add a delay so human rendering runs at the predefined framerate
            self.clock.tick(self.metadata["render_fps"])
        else:  # rgb_array
            # surfarray is indexed (x, y); transpose to (row, column, channel)
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
            )

    def close(self):
        """Shut down the pygame window if one was opened."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            # Drop the dead handles so a later human render re-creates them
            self.window = None
            self.clock = None

In [None]:
# Make the custom environment available to `make` under a versioned id,
# then instantiate it with an on-screen pygame window.
register(
    id="BanditSlipperyWalk-v0",
    entry_point=BanditSlipperyWalkEnv,
)

environment = make("BanditSlipperyWalk-v0", render_mode="human")

In [None]:
SEED = 42          # seeds both the environment RNG and the action sampler
N_STEPS = 10       # total environment steps to run
PAUSE_SECONDS = 2  # delay so each rendered frame stays visible

# Seed once, up front. Passing the same seed to every reset would replay the
# same random stream each episode and hide the stochastic slips.
environment.action_space.seed(SEED)
observation, info = environment.reset(seed=SEED)
try:
    for _ in range(N_STEPS):
        action = environment.action_space.sample()  # this is where you would insert your policy
        observation, reward, terminated, truncated, info = environment.step(action)

        print(info["log"], "\n\n")

        if terminated:
            print("Terminated", "\n\n")

        sleep(PAUSE_SECONDS)

        if terminated or truncated:
            # No seed here: continue the already-seeded random stream
            observation, info = environment.reset()
            sleep(PAUSE_SECONDS)
finally:
    # Release the pygame window even if the loop is interrupted
    environment.close()

{'current_state': 1, 'action': 0, 'next_state': 0} 


Terminated 


{'current_state': 1, 'action': 1, 'next_state': 2} 


Terminated 


{'current_state': 1, 'action': 0, 'next_state': 0} 


Terminated 


{'current_state': 1, 'action': 1, 'next_state': 0} 


Terminated 


{'current_state': 1, 'action': 1, 'next_state': 0} 


Terminated 


{'current_state': 1, 'action': 0, 'next_state': 2} 


Terminated 


{'current_state': 1, 'action': 1, 'next_state': 2} 


Terminated 


{'current_state': 1, 'action': 0, 'next_state': 0} 


Terminated 


{'current_state': 1, 'action': 0, 'next_state': 2} 


Terminated 


{'current_state': 1, 'action': 1, 'next_state': 2} 


Terminated 


