## Gym Environment
Example following Stable Baseline 3's documentation with a gym environment

In [1]:
import cv2
import gymnasium as gym
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv, VecMonitor
from stable_baselines3.common.env_checker import check_env

# Create the environment
env = gym.make("CartPole-v1", render_mode="rgb_array")


check_env(env)

In [1]:
import cv2
import gymnasium as gym
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv, VecMonitor

# Create the environment
env = gym.make("CartPole-v1", render_mode="rgb_array")

# Wrap the environment to record video
env = DummyVecEnv([lambda: env])
env = VecMonitor(env)
video_length = 1000  # Length of the video in steps

# Create the video recorder (specify save path and video length)
video_folder = 'videos/'
env = VecVideoRecorder(env, video_folder=video_folder,
                       record_video_trigger=lambda x: x == 0,  # record the first episode
                       video_length=video_length, name_prefix="cartpole_a2c")

# Train the model
model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)

# Reset the environment
obs = env.reset()
for i in range(video_length):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)

    # Render each frame (Optional if not viewing)
    env.render("human")

# Close the environment and the video recorder
env.close()
cv2.destroyAllWindows()

Using cpu device


DependencyNotInstalled: moviepy is not installed, run `pip install moviepy`

# Custom environment
Starting code for a snake game, to be converted into a custom gym environment

Below cell implements game logic and rendering with opencv (wasd movement, q to quit)

In [4]:
import numpy as np
import cv2
import random
import time

def collision_with_apple(apple_position, score):
    apple_position = [random.randrange(1,50)*10,random.randrange(1,50)*10]
    score += 1
    return apple_position, score

def collision_with_boundaries(snake_head):
    if snake_head[0]>=500 or snake_head[0]<0 or snake_head[1]>=500 or snake_head[1]<0 :
        return 1
    else:
        return 0

def collision_with_self(snake_position):
    snake_head = snake_position[0]
    if snake_head in snake_position[1:]:
        return 1
    else:
        return 0

img = np.zeros((500,500,3),dtype='uint8')
# Initial Snake and Apple position
snake_position = [[250,250],[240,250],[230,250]]
apple_position = [random.randrange(1,50)*10,random.randrange(1,50)*10]
score = 0
prev_button_direction = 1
button_direction = 1
snake_head = [250,250]
while True:
    cv2.imshow('a',img)
    cv2.waitKey(1)
    img = np.zeros((500,500,3),dtype='uint8')
    # Display Apple
    cv2.rectangle(img,(apple_position[0],apple_position[1]),(apple_position[0]+10,apple_position[1]+10),(0,0,255),3)
    # Display Snake
    for position in snake_position:
        cv2.rectangle(img,(position[0],position[1]),(position[0]+10,position[1]+10),(0,255,0),3)
    
    # Takes step after fixed time
    t_end = time.time() + 0.05
    k = -1
    while time.time() < t_end:
        if k == -1:
            k = cv2.waitKey(1)
        else:
            continue
            
    # 0-Left, 1-Right, 3-Up, 2-Down, q-Break
    # a-Left, d-Right, w-Up, s-Down

    if k == ord('a') and prev_button_direction != 1:
        button_direction = 0
    elif k == ord('d') and prev_button_direction != 0:
        button_direction = 1
    elif k == ord('w') and prev_button_direction != 2:
        button_direction = 3
    elif k == ord('s') and prev_button_direction != 3:
        button_direction = 2
    elif k == ord('q'):
        break
    else:
        button_direction = button_direction
    prev_button_direction = button_direction

    # Change the head position based on the button direction
    if button_direction == 1:
        snake_head[0] += 10
    elif button_direction == 0:
        snake_head[0] -= 10
    elif button_direction == 2:
        snake_head[1] += 10
    elif button_direction == 3:
        snake_head[1] -= 10

    # Increase Snake length on eating apple
    if snake_head == apple_position:
        apple_position, score = collision_with_apple(apple_position, score)
        snake_position.insert(0,list(snake_head))

    else:
        snake_position.insert(0,list(snake_head))
        snake_position.pop()
        
    # On collision kill the snake and print the score
    if collision_with_boundaries(snake_head) == 1 or collision_with_self(snake_position) == 1:
        font = cv2.FONT_HERSHEY_SIMPLEX
        img = np.zeros((500,500,3),dtype='uint8')
        cv2.putText(img,'Your Score is {}'.format(score),(140,250), font, 1,(255,255,255),2,cv2.LINE_AA)
        cv2.imshow('a',img)
        cv2.waitKey(0)
        break
        
cv2.destroyAllWindows()

In [2]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import cv2
import random
import time
from collections import deque

SNAKE_LEN_GOAL = 30

def collision_with_apple(apple_position, score):
    apple_position = [random.randrange(1,50)*10,random.randrange(1,50)*10]
    score += 1
    return apple_position, score

def collision_with_boundaries(snake_head):
    if snake_head[0]>=500 or snake_head[0]<0 or snake_head[1]>=500 or snake_head[1]<0 :
        return 1
    else:
        return 0

def collision_with_self(snake_position):
    snake_head = snake_position[0]
    if snake_head in snake_position[1:]:
        return 1
    else:
        return 0


class SnekEnv(gym.Env):

    def __init__(self):
        super(SnekEnv, self).__init__()
        # Define action and observation space
        # They must be gym.spaces objects
        # Example when using discrete actions:
        self.action_space = spaces.Discrete(4)
        # Example for using image as input (channel-first; channel-last also works):
        self.observation_space = spaces.Box(low=-500, high=500,
                                            shape=(5+SNAKE_LEN_GOAL,), dtype=np.float32)

    def step(self, action):
        self.prev_actions.append(action)
        cv2.imshow('a',self.img)
        cv2.waitKey(1)
        self.img = np.zeros((500,500,3),dtype='uint8')
        # Display Apple
        cv2.rectangle(self.img,(self.apple_position[0],self.apple_position[1]),(self.apple_position[0]+10,self.apple_position[1]+10),(0,0,255),3)
        # Display Snake
        for position in self.snake_position:
            cv2.rectangle(self.img,(position[0],position[1]),(position[0]+10,position[1]+10),(0,255,0),3)
        
        # Takes step after fixed time
        t_end = time.time() + 0.05
        k = -1
        while time.time() < t_end:
            if k == -1:
                k = cv2.waitKey(1)
            else:
                continue

        button_direction = action
        # Change the head position based on the button direction
        if button_direction == 1:
            self.snake_head[0] += 10
        elif button_direction == 0:
            self.snake_head[0] -= 10
        elif button_direction == 2:
            self.snake_head[1] += 10
        elif button_direction == 3:
            self.snake_head[1] -= 10


        apple_reward = 0
        # Increase Snake length on eating apple
        if self.snake_head == self.apple_position:
            self.apple_position, self.score = collision_with_apple(self.apple_position, self.score)
            self.snake_position.insert(0,list(self.snake_head))
            apple_reward = 10000

        else:
            self.snake_position.insert(0,list(self.snake_head))
            self.snake_position.pop()
        
        # On collision kill the snake and print the score
        if collision_with_boundaries(self.snake_head) == 1 or collision_with_self(self.snake_position) == 1:
            font = cv2.FONT_HERSHEY_SIMPLEX
            self.img = np.zeros((500,500,3),dtype='uint8')
            cv2.putText(self.img,'Your Score is {}'.format(self.score),(140,250), font, 1,(255,255,255),2,cv2.LINE_AA)
            cv2.imshow('a',self.img)
            self.done = True
        


        euclidean_dist_to_apple = np.linalg.norm(np.array(self.snake_head) - np.array(self.apple_position))

        self.total_reward = ((250 - euclidean_dist_to_apple) + apple_reward)/100

        self.reward = self.total_reward - self.prev_reward
        self.prev_reward = self.total_reward

        if self.done:
            self.reward = -10
        info = {}


        head_x = self.snake_head[0]
        head_y = self.snake_head[1]

        snake_length = len(self.snake_position)
        apple_delta_x = self.apple_position[0] - head_x
        apple_delta_y = self.apple_position[1] - head_y

        # create observation:

        observation = [head_x, head_y, apple_delta_x, apple_delta_y, snake_length] + list(self.prev_actions)
        observation = np.array(observation, dtype=np.float32)

        return observation, self.total_reward, self.done, False, info

    def reset(self, seed=0):
        random.seed(seed)
        np.random.seed(seed)

        self.img = np.zeros((500,500,3),dtype='uint8')
        # Initial Snake and Apple position
        self.snake_position = [[250,250],[240,250],[230,250]]
        self.apple_position = [random.randrange(1,50)*10,random.randrange(1,50)*10]
        self.score = 0
        self.prev_button_direction = 1
        self.button_direction = 1
        self.snake_head = [250,250]

        self.prev_reward = 0

        self.done = False

        head_x = self.snake_head[0]
        head_y = self.snake_head[1]

        snake_length = len(self.snake_position)
        apple_delta_x = self.apple_position[0] - head_x
        apple_delta_y = self.apple_position[1] - head_y

        self.prev_actions = deque(maxlen = SNAKE_LEN_GOAL)  # however long we aspire the snake to be
        for i in range(SNAKE_LEN_GOAL):
            self.prev_actions.append(-1) # to create history

        # create observation:
        observation = [head_x, head_y, apple_delta_x, apple_delta_y, snake_length] + list(self.prev_actions)
        observation = np.array(observation, dtype=np.float32)

        return observation, {}

In [3]:
from stable_baselines3.common.env_checker import check_env

env = SnekEnv()
check_env(env)

cv2.destroyAllWindows()

In [4]:
from stable_baselines3 import PPO
out_dir = "./out"

env = SnekEnv()
env.reset()

model = PPO("MlpPolicy", env, verbose=1, tensorboard_log=out_dir)

TIMESTEPS = 10000
for i in range(1, 5):
    model.learn(total_timesteps=TIMESTEPS, reset_num_timesteps=False, tb_log_name=f"run_{i}")
    model.save(out_dir + f"/model_{i}")
    cv2.destroyAllWindows()

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ./out\run_1_0


: 