https://pythonprogramming.net/custom-environment-reinforcement-learning-stable-baselines-3-tutorial/?completed=/saving-and-loading-reinforcement-learning-stable-baselines-3-tutorial/

In [1]:
import gym
from gym import spaces
import numpy as np
import cv2
import random
import time
from collections import deque

# The snake game costum enviroment

In [2]:
def collision_with_apple(apple_position):
    apple_position = [random.randrange(1,50)*10,random.randrange(1,50)*10]
    return apple_position

def collision_with_boundaries(snake_head):
    if snake_head[0]>=500 or snake_head[0]<0 or snake_head[1]>=500 or snake_head[1]<0 :
        return 1
    else:
        return 0

def collision_with_self(snake_position):
    snake_head = snake_position[0]
    if snake_head in snake_position[1:]:
        return 1
    else:
        return 0
    
class SnakeEnv(gym.Env):

    def __init__(self):
        
        super(SnakeEnv, self).__init__()
        
        self.snake_obs = 10 * 2
        self.snake_initial_length = 3
        
        # Define action and observation space
        # They must be gym.spaces objects
        # Example when using discrete actions:
        self.action_space = spaces.Discrete(4)
        # Example for using image as input (channel-first; channel-last also works):
        self.observation_space = spaces.Box(low=-500, high=500, shape=(5+self.snake_obs,), dtype=np.float32)
        # however long we aspire the snake to be
        self.prev_actions = deque(maxlen = self.snake_obs)  

    def step(self, action):
        
        # Change the head position based on the button direction
        if action == 1:
            self.snake_head[0] += 10
        elif action == 0:
            self.snake_head[0] -= 10
        elif action == 2:
            self.snake_head[1] += 10
        elif action == 3:
            self.snake_head[1] -= 10

        # Increase Snake length on eating apple
        apple_reward = 0.0
        if self.snake_head == self.apple_position:
            self.apple_position = collision_with_apple(self.apple_position)
            self.snake_position.insert(0,list(self.snake_head))
            apple_reward = 2.0 *(len(self.snake_position)-3)
        else:
            self.snake_position.insert(0,list(self.snake_head))
            self.snake_position.pop()
        
        # On collision kill the snake
        if collision_with_boundaries(self.snake_head) == 1 or collision_with_self(self.snake_position) == 1:
            self.done = True
            self.reward = -2.0
        else:
            euclidean_dist_to_apple = np.linalg.norm(np.array(self.snake_head) - np.array(self.apple_position))
            self.total_reward = (self.initial_distance - euclidean_dist_to_apple)/self.initial_distance + apple_reward 
            self.reward = self.total_reward - self.prev_total_reward
            self.prev_total_reward = self.total_reward
        
        observation = self._compute_observation()
        
        return observation, self.reward, self.done, {}

    def reset(self):

        # Initial Snake and Apple position
        self.snake_position = [[250,250],[240,250],[230,250]]
        self.apple_position = [random.randrange(1,50)*10,random.randrange(1,50)*10]
        self.snake_head = [250,250]
        self.initial_distance = np.linalg.norm(np.array(self.snake_head) - np.array(self.apple_position)) + 1.0e-6
        
        self.prev_total_reward = 0.0
        
        # empty actions
        for i in range(self.snake_obs):
            self.prev_actions.append(-1) 
            
        self.done = False
        observation = self._compute_observation()
        
        return observation
    
    def render(self, mode='human'):
        
        # Create image
        self.img = np.zeros((500,500,3),dtype='uint8')        
        # Display Apple
        cv2.rectangle(self.img,(self.apple_position[0],self.apple_position[1]),(self.apple_position[0]+10,self.apple_position[1]+10),(0,0,255),3)
        
        # Display Snake
        for position in self.snake_position:
            cv2.rectangle(self.img,(position[0],position[1]),(position[0]+10,position[1]+10),(0,255,0),3)
            
        # Display collision text
        if collision_with_boundaries(self.snake_head) == 1 or collision_with_self(self.snake_position) == 1:
            font = cv2.FONT_HERSHEY_SIMPLEX
            self.img = np.zeros((500,500,3),dtype='uint8')
            cv2.putText(self.img,'Snake length {}'.format(len(self.snake_position)),(140,250), font, 1,(255,255,255),2,cv2.LINE_AA)
        
        cv2.imshow('a',self.img)
        cv2.waitKey(1)
    
    def _compute_observation(self):
        head_x = self.snake_head[0]
        head_y = self.snake_head[1]
        snake_length = len(self.snake_position)
        apple_delta_x = self.apple_position[0] - head_x
        apple_delta_y = self.apple_position[1] - head_y
        
        snake_len = len(self.snake_position)
        for i in range(1,len(self.snake_position)):
            self.prev_actions.append(self.snake_position[i-1][0] -self.snake_position[i][0]) 
            self.prev_actions.append(self.snake_position[i-1][1] -self.snake_position[i][1]) 
        
        observation = [head_x, head_y, apple_delta_x, apple_delta_y, snake_length] + list(self.prev_actions)
        observation = np.array(observation)
        
        return observation

In [3]:
env = SnakeEnv()
episodes = 50

for episode in range(episodes):
    done = False
    obs = env.reset()
    while not done:
        random_action = env.action_space.sample()
        obs, reward, done, info = env.step(random_action)
        #env.render()
        print('reward',reward)

reward -0.035838481295298825
reward -0.0545014593379072
reward -0.03677418377889932
reward -0.03929703964333506
reward 0.050778976986021365
reward -0.04339655368028854
reward -0.04895613030279258
reward -2.0
reward 0.018482914825336278
reward -0.035848060126326346
reward -0.01784306680064689
reward -2.0
reward 0.037161065112331325
reward 0.03104749594087975
reward -0.038351056492660636
reward -2.0
reward 0.035679243080870386
reward -2.0
reward 0.059989001100165695
reward -2.0
reward 0.026892213878305954
reward -2.0
reward -2.0
reward 0.06546137914360955
reward -2.0
reward -0.0558191538301098
reward -0.05596569928451606
reward -0.013192371543053402
reward -2.0
reward -0.045515870485959335
reward -0.03726470716315363
reward -0.03910206990071974
reward -0.04075406357528903
reward 0.04075406357528903
reward -0.04371620012417811
reward -0.03927629975056743
reward -0.04357215676975609
reward -2.0
reward -2.0
reward -2.0
reward -2.0
reward -0.10431524350195896
reward -2.0
reward -0.0344167169

# Use a stable baselines 3 algorithm to compute the optimal policy

In [4]:
import gym
import os
from stable_baselines3 import PPO



In [5]:
env = SnakeEnv()

In [6]:
model_name = "PPO"
models_dir = "models/" + model_name
logdir = "logs"

In [7]:
if not os.path.exists(models_dir):
    os.makedirs(models_dir)

if not os.path.exists(logdir):
    os.makedirs(logdir)

## Save models at different iterations

In [8]:
env.reset()
if model_name =="PPO":
    model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=logdir)
elif model_name =="A2C":
    model = A2C('MlpPolicy', env, verbose=1, tensorboard_log=logdir)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [None]:
TIMESTEPS = 10000
iters = 0
for i in range(100):
    iters += 1
    model.learn(total_timesteps=TIMESTEPS, reset_num_timesteps=False, tb_log_name=model_name)
    model.save(f"{models_dir}/{TIMESTEPS*iters}")

Logging to logs\PPO_0
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 4.15     |
|    ep_rew_mean     | -2.03    |
| time/              |          |
|    fps             | 908      |
|    iterations      | 1        |
|    time_elapsed    | 2        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 4.4         |
|    ep_rew_mean          | -2.02       |
| time/                   |             |
|    fps                  | 998         |
|    iterations           | 2           |
|    time_elapsed         | 4           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.013251333 |
|    clip_fraction        | 0.193       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.38       |
|    explained_variance   | -0.755      |
|    lea

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 19          |
|    ep_rew_mean          | -1.46       |
| time/                   |             |
|    fps                  | 1316        |
|    iterations           | 3           |
|    time_elapsed         | 4           |
|    total_timesteps      | 26624       |
| train/                  |             |
|    approx_kl            | 0.014264483 |
|    clip_fraction        | 0.119       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.678      |
|    explained_variance   | 0.518       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0083     |
|    n_updates            | 120         |
|    policy_gradient_loss | -0.018      |
|    value_loss           | 0.0643      |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 21.9  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 25.8        |
|    ep_rew_mean          | -1.26       |
| time/                   |             |
|    fps                  | 1261        |
|    iterations           | 4           |
|    time_elapsed         | 6           |
|    total_timesteps      | 49152       |
| train/                  |             |
|    approx_kl            | 0.016824085 |
|    clip_fraction        | 0.121       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.508      |
|    explained_variance   | 0.628       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.000624   |
|    n_updates            | 230         |
|    policy_gradient_loss | -0.016      |
|    value_loss           | 0.0685      |
-----------------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 28.2

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 31.5        |
|    ep_rew_mean          | -1.15       |
| time/                   |             |
|    fps                  | 1261        |
|    iterations           | 5           |
|    time_elapsed         | 8           |
|    total_timesteps      | 71680       |
| train/                  |             |
|    approx_kl            | 0.022134483 |
|    clip_fraction        | 0.141       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.496      |
|    explained_variance   | 0.534       |
|    learning_rate        | 0.0003      |
|    loss                 | 0.0148      |
|    n_updates            | 340         |
|    policy_gradient_loss | -0.0157     |
|    value_loss           | 0.169       |
-----------------------------------------
Logging to logs\PPO_0
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 31.1 

Logging to logs\PPO_0
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 35.5     |
|    ep_rew_mean     | -1.12    |
| time/              |          |
|    fps             | 2217     |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 94208    |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 35.2        |
|    ep_rew_mean          | -1.28       |
| time/                   |             |
|    fps                  | 1390        |
|    iterations           | 2           |
|    time_elapsed         | 2           |
|    total_timesteps      | 96256       |
| train/                  |             |
|    approx_kl            | 0.020413112 |
|    clip_fraction        | 0.137       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.482      |
|    explained_variance   | 0.589       |
|    lea

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 36.6        |
|    ep_rew_mean          | -1.11       |
| time/                   |             |
|    fps                  | 1292        |
|    iterations           | 3           |
|    time_elapsed         | 4           |
|    total_timesteps      | 118784      |
| train/                  |             |
|    approx_kl            | 0.020191263 |
|    clip_fraction        | 0.141       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.436      |
|    explained_variance   | 0.621       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.000828   |
|    n_updates            | 570         |
|    policy_gradient_loss | -0.0151     |
|    value_loss           | 0.0999      |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 39.5  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 49.3        |
|    ep_rew_mean          | -1.23       |
| time/                   |             |
|    fps                  | 1270        |
|    iterations           | 4           |
|    time_elapsed         | 6           |
|    total_timesteps      | 141312      |
| train/                  |             |
|    approx_kl            | 0.018224586 |
|    clip_fraction        | 0.127       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.417      |
|    explained_variance   | 0.382       |
|    learning_rate        | 0.0003      |
|    loss                 | 0.274       |
|    n_updates            | 680         |
|    policy_gradient_loss | -0.0126     |
|    value_loss           | 0.408       |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 42.8  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 42.3        |
|    ep_rew_mean          | -1.14       |
| time/                   |             |
|    fps                  | 1257        |
|    iterations           | 5           |
|    time_elapsed         | 8           |
|    total_timesteps      | 163840      |
| train/                  |             |
|    approx_kl            | 0.020073323 |
|    clip_fraction        | 0.127       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.415      |
|    explained_variance   | 0.221       |
|    learning_rate        | 0.0003      |
|    loss                 | 0.699       |
|    n_updates            | 790         |
|    policy_gradient_loss | -0.0104     |
|    value_loss           | 1.14        |
-----------------------------------------
Logging to logs\PPO_0
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 42.8 

Logging to logs\PPO_0
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 39.4     |
|    ep_rew_mean     | -1.25    |
| time/              |          |
|    fps             | 2148     |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 186368   |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 35          |
|    ep_rew_mean          | -1.21       |
| time/                   |             |
|    fps                  | 1342        |
|    iterations           | 2           |
|    time_elapsed         | 3           |
|    total_timesteps      | 188416      |
| train/                  |             |
|    approx_kl            | 0.031012382 |
|    clip_fraction        | 0.142       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.39       |
|    explained_variance   | 0.423       |
|    lea

----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 38.3       |
|    ep_rew_mean          | -1.41e+05  |
| time/                   |            |
|    fps                  | 1344       |
|    iterations           | 3          |
|    time_elapsed         | 4          |
|    total_timesteps      | 210944     |
| train/                  |            |
|    approx_kl            | 0.02997461 |
|    clip_fraction        | 0.148      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.397     |
|    explained_variance   | 0.546      |
|    learning_rate        | 0.0003     |
|    loss                 | 0.00363    |
|    n_updates            | 1020       |
|    policy_gradient_loss | -0.0175    |
|    value_loss           | 0.142      |
----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 40.6        |
|    ep_rew_m

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 48.8        |
|    ep_rew_mean          | -0.786      |
| time/                   |             |
|    fps                  | 1249        |
|    iterations           | 4           |
|    time_elapsed         | 6           |
|    total_timesteps      | 233472      |
| train/                  |             |
|    approx_kl            | 0.018406473 |
|    clip_fraction        | 0.11        |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.353      |
|    explained_variance   | 0.364       |
|    learning_rate        | 0.0003      |
|    loss                 | 0.0719      |
|    n_updates            | 1130        |
|    policy_gradient_loss | -0.012      |
|    value_loss           | 0.396       |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 45.8  

In [None]:
env.reset()

model_path = f"{models_dir}/650000.zip"

if model_name =="PPO":
    model = PPO.load(model_path, env=env)
elif model_name =="A2C":
    model = A2C.load(model_path, env=env)

episodes = 50
for ep in range(episodes):
    obs = env.reset()
    done = False
    while not done:
        # pass observation to model to get predicted action
        action, _states = model.predict(obs)
        # pass action to env and get info back
        obs, rewards, done, info = env.step(action)
        
        # show the environment on the screen
        env.render()

In [None]:
cv2.destroyAllWindows()