In [1]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import cv2
import random
import time
from collections import deque

SNAKE_LEN_GOAL = 30

def collision_with_apple(apple_position, score):
	apple_position = [random.randrange(1,50)*10,random.randrange(1,50)*10]
	score += 1
	return apple_position, score

def collision_with_boundaries(snake_head):
	if snake_head[0]>=500 or snake_head[0]<0 or snake_head[1]>=500 or snake_head[1]<0 :
		return 1
	else:
		return 0

def collision_with_self(snake_position):
	snake_head = snake_position[0]
	if snake_head in snake_position[1:]:
		return 1
	else:
		return 0


class SnekEnv(gym.Env):

	def __init__(self):
		super(SnekEnv, self).__init__()
		# Define action and observation space
		# They must be gym.spaces objects
		# Example when using discrete actions:
		self.action_space = spaces.Discrete(4)
		# Example for using image as input (channel-first; channel-last also works):
		self.observation_space = spaces.Box(low=-500, high=500,
											shape=(5+SNAKE_LEN_GOAL,), dtype=np.float32)

	def step(self, action):
		self.prev_actions.append(action)
		cv2.imshow('a',self.img)
		#cv2.waitKey(1)
        
		self.img = np.zeros((500,500,3),dtype='uint8')
		# Display Apple
		cv2.rectangle(self.img,(self.apple_position[0],self.apple_position[1]),(self.apple_position[0]+10,self.apple_position[1]+10),(0,0,255),3)
		# Display Snake
		for position in self.snake_position:
			cv2.rectangle(self.img,(position[0],position[1]),(position[0]+10,position[1]+10),(0,255,0),3)

		button_direction = action
		# Change the head position based on the button direction
		if button_direction == 1:
			self.snake_head[0] += 10
		elif button_direction == 0:
			self.snake_head[0] -= 10
		elif button_direction == 2:
			self.snake_head[1] += 10
		elif button_direction == 3:
			self.snake_head[1] -= 10

		# Increase Snake length on eating apple
		if self.snake_head == self.apple_position:
			self.apple_position, self.score = collision_with_apple(self.apple_position, self.score)
			self.snake_position.insert(0,list(self.snake_head))

		else:
			self.snake_position.insert(0,list(self.snake_head))
			self.snake_position.pop()
		
		# On collision kill the snake and print the score
		if collision_with_boundaries(self.snake_head) == 1 or collision_with_self(self.snake_position) == 1:
			font = cv2.FONT_HERSHEY_SIMPLEX
			self.img = np.zeros((500,500,3),dtype='uint8')
			cv2.putText(self.img,'Your Score is {}'.format(self.score),(140,250), font, 1,(255,255,255),2,cv2.LINE_AA)
			cv2.imshow('a',self.img)
			self.terminated = True

		euclidean_dist_to_apple = np.linalg.norm(np.array(self.snake_head) - np.array(self.apple_position))
		self.total_reward = (((len(self.snake_position) - 3)*700) + max(0, 700-euclidean_dist_to_apple))/70 # default length is 3
		self.reward = self.total_reward - self.prev_reward + self.score/10
		self.prev_reward = self.total_reward

		if self.terminated:
			self.reward = -10

		info = {}

		info = {"score": self.score}
		

		head_x = self.snake_head[0]
		head_y = self.snake_head[1]

		snake_length = len(self.snake_position)
		apple_delta_x = self.apple_position[0] - head_x
		apple_delta_y = self.apple_position[1] - head_y

		# create observation:

		observation = [head_x, head_y, apple_delta_x, apple_delta_y, snake_length] + list(self.prev_actions)
		observation = np.array(observation, dtype=np.float32)

		self.truncated = False;

		return observation, self.reward, self.terminated, self.truncated, info

	def reset(self, seed = None):
		super().reset(seed=seed)
		if seed is not None:
			random.seed(seed)
		else:
			random.seed()
        
		self.img = np.zeros((500,500,3),dtype='uint8')
		# Initial Snake and Apple position
		self.snake_position = [[250,250],[240,250],[230,250]]
		random.seed(seed)
		self.apple_position = [random.randrange(1,50)*10,random.randrange(1,50)*10]
		self.score = 0
		self.prev_button_direction = 1
		self.button_direction = 1
		self.snake_head = [250,250]

		self.prev_reward = 0

		self.terminated = False

		head_x = self.snake_head[0]
		head_y = self.snake_head[1]

		snake_length = len(self.snake_position)
		apple_delta_x = self.apple_position[0] - head_x
		apple_delta_y = self.apple_position[1] - head_y

		self.prev_actions = deque(maxlen = SNAKE_LEN_GOAL)  # however long we aspire the snake to be
		for i in range(SNAKE_LEN_GOAL):
			self.prev_actions.append(-1) # to create history

		# create observation:
		observation = [head_x, head_y, apple_delta_x, apple_delta_y, snake_length] + list(self.prev_actions)
		observation = np.array(observation, dtype=np.float32)

		info = {}
    
		return observation, info

	def render(self):
		t_end = time.time() + 0.05
		k = -1
		cv2.imshow('snake', self.img)
		while time.time() < t_end:
			if k == -1:
				k = cv2.waitKey(1)
			else:
				continue

In [2]:
from stable_baselines3.common.callbacks import BaseCallback

class ScoreLoggerCallback(BaseCallback):
    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        infos = self.locals.get("infos", [])
        for info in infos:
            if "score" in info:
                self.logger.record("env/score", info["score"])
        return True

In [3]:
# Load Model

In [None]:
from stable_baselines3 import PPO

models_dir = "/home/prh/Desktop/Local_Prgm_Projects/TetoML/models/1760159974/"

env = SnekEnv()
env.reset()

model_path = f"{models_dir}/17000000.zip"
model = PPO.load(model_path, env=env)

episodes = 5

for ep in range(episodes):
    obs, info = env.reset()
    done = False
    while not done:
        action, _states = model.predict(obs)
        obs, rewards, terminated, truncated, info = env.step(action)
        env.render()
        done = truncated or terminated
        print(rewards)

Exception: Can't get attribute 'FloatSchedule' on <module 'stable_baselines3.common.utils' from '/home/prh/anaconda3/envs/TetoML3/lib/python3.8/site-packages/stable_baselines3/common/utils.py'>
Exception: Can't get attribute 'FloatSchedule' on <module 'stable_baselines3.common.utils' from '/home/prh/anaconda3/envs/TetoML3/lib/python3.8/site-packages/stable_baselines3/common/utils.py'>


Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


qt.qpa.plugin: Could not find the Qt platform plugin "wayland" in "/home/prh/anaconda3/envs/TetoML3/lib/python3.8/site-packages/cv2/qt/plugins"


6.828056669832212
0.01612589668571296
0.009715212689897612
0.0032450779350341463
-0.0032450779350341463
-0.009715212689897612
0.14224270352079138
0.14218169535531544
-0.017713881987844182
0.14119490778362032
0.1410137815395256
0.14080158418037136
0.14055084364748005
0.14025172218876403
0.1398910682145731
-0.034628387897812196
0.13696418140365552
-0.04670253834981253
-0.05563325515603079
0.12877331403965808
0.1266348255510099
-0.0726343891805179
0.11883559923879972
0.11497162353461654
0.11022977689381896
-0.09765181232144471
0.09765181232144471
0.09002083039066378
0.08085983827762888
0.06995846029392006
0.057181166383088566
-0.13468989074108606
0.038184074744608054
0.023451331307922985
0.14187534140555513
0.008893964042648861
0.14285714285714235
0.14285714285714235
0.14285714285714413
0.14285714285714235
0.14285714285714235
0.14285714285714235
0.14285714285714413
-0.05917336605329915
0.05917336605329915
-0.14285714285714413
-0.14285714285714235
-0.14285714285714235
-0.14285714285714235
