# Environment Definition

In [None]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import cv2
import random
import time
from collections import deque

SNAKE_LEN_GOAL = 30

def collision_with_apple(apple_position, score):
    """Respawn the apple after it has been eaten.

    The incoming ``apple_position`` is ignored; a fresh grid-aligned position
    with both coordinates in ``{10, 20, ..., 490}`` is drawn from the global
    ``random`` module (x first, then y).

    Returns:
        ``(new_apple_position, score + 1)``
    """
    new_apple = [random.randrange(1, 50) * 10 for _axis in range(2)]
    return new_apple, score + 1

def collision_with_boundaries(snake_head):
    """Return 1 if the head ``[x, y]`` lies outside the 500x500 board, else 0.

    Valid coordinates are ``0 <= x < 500`` and ``0 <= y < 500``.
    """
    on_board = 0 <= snake_head[0] < 500 and 0 <= snake_head[1] < 500
    return 0 if on_board else 1

def collision_with_self(snake_position):
    """Return 1 if the head (first segment) overlaps any later segment, else 0."""
    head, body = snake_position[0], snake_position[1:]
    return 1 if head in body else 0


class SnekEnv(gym.Env):
    """Snake on a 500x500-pixel board with a 10-pixel grid, as a Gymnasium env.

    Actions (``Discrete(4)``) move the head by one cell:
    0 -> x - 10, 1 -> x + 10, 2 -> y + 10, 3 -> y - 10.
    An action that would reverse straight back into the body (0<->1, 2<->3)
    is replaced by the previous direction.

    Observations are float32 vectors of length ``5 + 3 * SNAKE_LEN_GOAL``:
    ``[head_x, head_y, apple_dx, apple_dy, snake_length]``, then the last
    ``SNAKE_LEN_GOAL`` requested actions (``-1`` before any were taken), then
    the ``(dx, dy)`` offsets of up to ``SNAKE_LEN_GOAL`` body segments from the
    head (zero-padded).

    Episodes terminate on leaving the board or hitting the body; they are
    never truncated.
    """

    # Opposite of each direction; requesting it is treated as "keep going".
    _OPPOSITE = {0: 1, 1: 0, 2: 3, 3: 2}
    # Head displacement (dx, dy) in pixels for each action.
    _MOVES = {0: (-10, 0), 1: (10, 0), 2: (0, 10), 3: (0, -10)}

    def __init__(self):
        super().__init__()
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Box(
            low=-500, high=500,
            shape=(5 + SNAKE_LEN_GOAL + SNAKE_LEN_GOAL * 2,), dtype=np.float32,
        )
        # Frame drawing is off until reset(render=True) switches it on.
        self.rendering = False
        self.img = np.zeros((500, 500, 3), dtype='uint8')

    def step(self, action):
        """Advance the game by one move.

        Args:
            action: Requested direction in ``range(4)``; numpy integer scalars
                (as returned by ``model.predict``) are accepted.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where
            ``info == {"score": apples_eaten}`` and ``truncated`` is False.
        """
        action = int(action)
        # History records what the agent asked for, before reversal filtering.
        self.prev_actions.append(action)

        # Refuse a 180-degree turn by continuing in the previous direction.
        # The corrected action is the one actually used for the move below.
        if self._OPPOSITE.get(self.prev_button_direction) == action:
            action = self.prev_button_direction
        self.prev_button_direction = action
        self.button_direction = action

        dx, dy = self._MOVES.get(action, (0, 0))
        self.snake_head[0] += dx
        self.snake_head[1] += dy

        dist_prev = self.prev_distance_to_apple
        dist_curr = self._distance_to_apple()
        reward = 0.01

        if self.snake_head == self.apple_position:
            # Eat: grow by keeping the tail, then place a new apple off the body.
            self.score += 1
            self.snake_position.insert(0, list(self.snake_head))
            self.apple_position = self._spawn_apple()
            reward += 3.0
        else:
            # Move: new head in, old tail out; shape reward by progress to the apple.
            self.snake_position.insert(0, list(self.snake_head))
            self.snake_position.pop()
            reward += 0.1 if dist_curr < dist_prev else -0.1

        if collision_with_boundaries(self.snake_head) or collision_with_self(self.snake_position):
            self.terminated = True
            reward -= 5  # death penalty

        # Penalise every body segment closer than SNAKE_LEN_GOAL pixels to the head.
        # NOTE: the first two segments are always within 20 px of the head, so
        # this term contributes at least -0.1 on every step.
        head = np.array(self.snake_head)
        for seg in self.snake_position[1:]:
            if np.linalg.norm(head - np.array(seg)) < SNAKE_LEN_GOAL:
                reward -= 0.05

        # Measure against the *current* apple, which may have just respawned;
        # otherwise the step after eating would always count as "moving away".
        self.prev_distance_to_apple = self._distance_to_apple()
        reward -= 0.005  # slight time penalty
        self.reward = reward / 5.0
        self.truncated = False

        # Draw after the state update so render() shows the frame that was just produced.
        if self.rendering:
            self._draw_frame()

        return self._get_obs(), self.reward, self.terminated, self.truncated, {"score": self.score}

    def reset(self, seed=None, render=False, options=None):
        """Start a new episode and return ``(observation, info)``.

        Args:
            seed: Seeds this env's own RNG (``self.np_random``). With None the
                existing RNG stream continues, so a run seeded once stays
                reproducible; the global ``random`` module is not touched.
            render: If True, turn on frame drawing (it stays on afterwards).
            options: Accepted for Gymnasium API compatibility; unused.
        """
        if render:
            self.rendering = True
        super().reset(seed=seed, options=options)

        self.snake_position = [[250, 250], [240, 250], [230, 250]]
        self.snake_head = [250, 250]
        self.score = 0
        # The initial body extends to the left of the head, so start heading right (1).
        self.prev_button_direction = 1
        self.button_direction = 1
        self.apple_position = self._spawn_apple()
        self.prev_distance_to_apple = self._distance_to_apple()
        self.prev_reward = 0
        self.terminated = False
        self.truncated = False
        # Last SNAKE_LEN_GOAL requested actions, pre-filled with -1 ("no action yet").
        self.prev_actions = deque([-1] * SNAKE_LEN_GOAL, maxlen=SNAKE_LEN_GOAL)

        if self.rendering:
            self._draw_frame()
        else:
            self.img = np.zeros((500, 500, 3), dtype='uint8')

        return self._get_obs(), {}

    def render(self):
        """Show the current frame in the 'snake' window for about 50 ms."""
        deadline = time.time() + 0.05
        cv2.imshow('snake', self.img)
        cv2.waitKey(50)  # pumps GUI events; may return early on a key press
        remaining = deadline - time.time()
        if remaining > 0:
            time.sleep(remaining)  # keep a steady frame pace without busy-waiting

    def _distance_to_apple(self):
        """Euclidean pixel distance from the head to the apple."""
        return np.linalg.norm(np.array(self.snake_head) - np.array(self.apple_position))

    def _spawn_apple(self):
        """Pick a grid-aligned apple cell (coords in 10..490) not occupied by the snake."""
        occupied = {tuple(seg) for seg in self.snake_position}
        for _ in range(1000):
            candidate = [int(self.np_random.integers(1, 50)) * 10,
                         int(self.np_random.integers(1, 50)) * 10]
            if tuple(candidate) not in occupied:
                break
        # After 1000 misses (an almost full board) the last sample is used anyway.
        return candidate

    def _get_obs(self):
        """Build the flat float32 observation vector described in the class docstring."""
        head_x, head_y = self.snake_head
        tail_obs = []
        for seg in self.snake_position[1:SNAKE_LEN_GOAL + 1]:
            tail_obs.extend((seg[0] - head_x, seg[1] - head_y))
        tail_obs.extend([0] * (2 * SNAKE_LEN_GOAL - len(tail_obs)))  # zero padding

        observation = [
            head_x,
            head_y,
            self.apple_position[0] - head_x,
            self.apple_position[1] - head_y,
            len(self.snake_position),
        ]
        observation += list(self.prev_actions)
        observation += tail_obs
        return np.array(observation, dtype=np.float32)

    def _draw_frame(self):
        """Redraw ``self.img`` from the current state (score screen once terminated)."""
        self.img = np.zeros((500, 500, 3), dtype='uint8')
        if self.terminated:
            cv2.putText(self.img, 'Your Score is {}'.format(self.score), (140, 250),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
            return
        ax, ay = self.apple_position
        cv2.rectangle(self.img, (ax, ay), (ax + 10, ay + 10), (0, 0, 255), 3)
        for x, y in self.snake_position:
            cv2.rectangle(self.img, (x, y), (x + 10, y + 10), (0, 255, 0), 3)

## Environment Test: Random Actions

# Training

In [None]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
import os
import time

# One timestamp per run so the checkpoint and TensorBoard directories always match.
run_id = int(time.time())
models_dir = f"models/{run_id}/"
logdir = f"logs/{run_id}/"

# exist_ok avoids the check-then-create race and is a no-op if the dir exists.
os.makedirs(models_dir, exist_ok=True)
os.makedirs(logdir, exist_ok=True)

In [None]:
from stable_baselines3.common.monitor import Monitor

# 16 SnekEnv copies behind a DummyVecEnv (all stepped in this process), seeded from 0.
env = make_vec_env(env_id=SnekEnv, n_envs=16, seed=0, vec_env_cls=DummyVecEnv)
env.reset()

In [None]:
# Separate MLP heads for the policy (pi) and the value function (vf).
# SB3 >= 1.8 takes this as a plain dict; the list-wrapped form
# [dict(pi=..., vf=...)] is deprecated and only unwrapped with a warning.
policy_kwargs = dict(
    net_arch=dict(pi=[256, 256, 128], vf=[256, 256, 128]),
)

model = PPO(
    "MlpPolicy",
    env,
    policy_kwargs=policy_kwargs,
    n_steps=1024,       # steps collected per environment per rollout
    batch_size=4096,    # minibatch size for each gradient step
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
    ent_coef=0.01,
    learning_rate=3e-4,
    clip_range=0.2,
    device="cuda",
    verbose=1,
    tensorboard_log=logdir,
    seed=0,
)

In [None]:
# Print the structure of the PPO actor-critic policy network.
policy = model.policy
print(policy)

In [None]:
#for k, v in model.__dict__.items():
#    print(k, v)

In [None]:
from stable_baselines3.common.callbacks import BaseCallback

class ScoreLoggerCallback(BaseCallback):
    """Log the final score of each finished episode as ``env/score``.

    Scores are recorded only for environments whose episode just ended, and
    all such scores between two logger dumps are averaged via
    ``logger.record_mean``. If the algorithm exposes no ``dones`` in its
    locals, every reported score is logged instead.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        infos = self.locals.get("infos", [])
        dones = self.locals.get("dones")
        if dones is None:
            # No done flags available: fall back to logging every reported score.
            dones = [True] * len(infos)
        for info, done in zip(infos, dones):
            if done and "score" in info:
                self.logger.record_mean("env/score", info["score"])
        return True  # never stop training from this callback

In [None]:
TIMESTEPS = 100000
iters = 0
logcb = ScoreLoggerCallback()

# Train indefinitely (interrupt the kernel to stop), checkpointing after each
# chunk of roughly TIMESTEPS environment steps.
while True:
    iters += 1
    model.learn(
        total_timesteps=TIMESTEPS,
        reset_num_timesteps=False,
        tb_log_name="PPO",
        callback=logcb,
        progress_bar=True,
    )
    # learn() runs whole rollouts (n_steps * n_envs steps), so the real step
    # count overshoots TIMESTEPS * iters; name checkpoints by the true counter.
    model.save(os.path.join(models_dir, str(model.num_timesteps)))

# Load and Test Model

In [None]:
from stable_baselines3 import PPO

models_dir = "/home/prh/Desktop/Local_Prgm_Projects/TetoML/Exploration/Round 2/models/1760239689/"

env = SnekEnv()
env.reset(render=True)  # turns rendering on; the flag persists for later resets

model_path = f"{models_dir}/1300000.zip"
model = PPO.load(model_path, env=env)

episodes = 1
# The env never truncates, and a deterministic policy can circle forever.
max_steps = 5000

for ep in range(episodes):
    obs, info = env.reset()
    done = False
    steps = 0
    while not done and steps < max_steps:
        # Greedy (most likely) action rather than a sample: the usual choice for evaluation.
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        env.render()
        done = terminated or truncated
        steps += 1
        print(reward)

cv2.destroyAllWindows()