In [None]:
import cv2
import numpy as np
import random
from keras import utils
from keras.models import Sequential, load_model
from keras.layers import Dense, Dropout, Conv2D, MaxPooling2D, Flatten
from collections import deque
from time import time, sleep
from keras.optimizers import Adam
import tensorflow as tf
from PIL import Image
from enum import Enum

# Turn off Keras interactive logging (progress bars from predict/fit) so the many
# per-step model calls in the training loop do not flood the notebook output.
utils.disable_interactive_logging()

In [None]:
class Direction(Enum):
    """Absolute headings, numbered 1-4 in the order right, left, up, down.

    NOTE(review): not referenced by the visible code; SnakeGameEnv tracks its
    heading as a (d_row, d_col) tuple instead.
    """
    RIGHT, LEFT, UP, DOWN = 1, 2, 3, 4

In [None]:
class SnakeGameEnv:
    """Minimal Snake environment on a ``grid_size`` x ``grid_size`` board.

    Cells are ``(row, col)`` tuples and ``self.snake`` lists the snake's cells
    with the head at index 0. Actions are relative to the current heading:
    ``1`` turns right (clockwise), ``2`` turns left (counter-clockwise), and any
    other value (e.g. ``0``) keeps going straight.

    Observations are ``(grid_size, grid_size, 3)`` uint8 RGB arrays from
    :meth:`get_image`.
    """

    def __init__(self, grid_size=10):
        self.grid_size = grid_size
        # Never written by this class. It exists for the older state-feature helpers
        # that read ``self.grid`` (with this class they always see all zeros).
        self.grid = np.zeros((self.grid_size, self.grid_size), dtype=np.uint8)
        # Headings in clockwise order as (d_row, d_col): right, down, left, up.
        self.clock_wise = [(0, 1), (1, 0), (0, -1), (-1, 0)]
        self.MOVE_PENALTY = -2
        self.COLLISION_PENALTY = -200
        self.FOOD_REWARD = 100
        self.reset()

    def generate_food(self):
        """Return a uniformly random empty cell, or ``None`` if the snake fills the board."""
        occupied = set(self.snake)
        free_cells = [
            (row, col)
            for row in range(self.grid_size)
            for col in range(self.grid_size)
            if (row, col) not in occupied
        ]
        return random.choice(free_cells) if free_cells else None

    def step(self, action):
        """Advance the game by one tick.

        Args:
            action: 1 = turn right, 2 = turn left, anything else = straight.

        Returns:
            ``(image, reward, done)``. ``reward`` is ``COLLISION_PENALTY`` on
            hitting a wall or the body, ``FOOD_REWARD`` on eating, and
            ``MOVE_PENALTY`` otherwise. ``done`` also becomes True when the
            snake fills the whole board.

        Calling ``step`` after ``done`` keeps simulating from the current
        state; call :meth:`reset` to start a new episode.
        """
        idx = self.clock_wise.index(self.direction)
        if action == 1:
            self.direction = self.clock_wise[(idx + 1) % 4]
        elif action == 2:
            self.direction = self.clock_wise[(idx - 1) % 4]

        head_row, head_col = self.snake[0]
        new_head = (head_row + self.direction[0], head_col + self.direction[1])
        ate = new_head == self.food

        # Unless the snake eats, its tail leaves its cell this same tick, so
        # moving into the current tail cell is not a collision.
        blocking = self.snake if ate else self.snake[:-1]
        in_bounds = 0 <= new_head[0] < self.grid_size and 0 <= new_head[1] < self.grid_size
        if not in_bounds or new_head in blocking:
            self.done = True
            return self.get_image(), self.COLLISION_PENALTY, self.done

        self.snake.insert(0, new_head)
        if ate:
            self.score += 1
            reward = self.FOOD_REWARD
            self.food = self.generate_food()
            if self.food is None:
                # Board completely filled: nothing left to eat, so end the episode.
                self.done = True
        else:
            self.snake.pop()
            reward = self.MOVE_PENALTY

        return self.get_image(), reward, self.done

    def reset(self):
        """Start a new episode with a 1-cell snake in the centre heading down; return the first image."""
        self.snake = [(self.grid_size // 2, self.grid_size // 2)]
        self.direction = (1, 0)
        self.food = self.generate_food()
        self.done = False
        self.score = 0
        return self.get_image()

    def render(self, wait_ms=None):
        """Show the board upscaled to 300x300 in the OpenCV window "image".

        OpenCV reads the array as BGR, so in this window the (255, 0, 0) snake
        appears blue and the (0, 255, 0) food appears green.

        Args:
            wait_ms: if given, call ``cv2.waitKey(wait_ms)`` so the window
                actually repaints (imshow alone does not pump the GUI loop).
                Leave as None when the caller runs ``cv2.waitKey`` itself.

        Returns:
            The key code from ``cv2.waitKey`` when ``wait_ms`` is given, else None.
        """
        img_resized = cv2.resize(self.get_image(), (300, 300), interpolation=cv2.INTER_NEAREST)
        cv2.imshow("image", img_resized)
        if wait_ms is not None:
            return cv2.waitKey(wait_ms)
        return None

    def get_image(self):
        """Return the board as a ``(grid_size, grid_size, 3)`` uint8 RGB array.

        Snake cells are (255, 0, 0) (red in RGB) and the food cell is
        (0, 255, 0) (green). Everything else is black. The head is drawn the same
        as the body. No food is drawn once the board is full (``food is None``).
        """
        colors = {
            1: (255, 0, 0),   # Snake (red in RGB)
            2: (0, 255, 0),   # Food (green in RGB)
        }
        rgb_frame = np.zeros((self.grid_size, self.grid_size, 3), dtype=np.uint8)
        for row, col in self.snake:
            rgb_frame[row, col, :] = colors[1]
        if self.food is not None:
            rgb_frame[self.food[0], self.food[1], :] = colors[2]
        return rgb_frame


In [None]:
class DQNAgent:
    """DQN agent with experience replay and a separately synced target network.

    Args:
        state_size: input shape of the Q-network, e.g. ``(H, W, 3)``.
        action_size: number of discrete actions (width of the Q-value output).
    """

    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=100000)
        # Total transitions ever stored. Can exceed len(self.memory) once the
        # deque starts evicting old entries; replay() checks len(self.memory).
        self.memory_len = 0
        self.min_replay_size = 1000
        # Intended environment steps between replay() calls (not read inside
        # this class; the training loop decides when to call replay()).
        self.replay_frequency = 5
        self.batch_size = 16
        self.gamma = 0.95
        self.epsilon = 1
        self.epsilon_decay = 0.999
        self.epsilon_min = 0.01
        self.model = self.build_model()
        self.target_model = self.build_model()
        self.target_model_update_frequency = 100    # After every 100 episodes
        self.update_target_model()

    def build_model(self):
        """Build and compile the Q-network (state image -> one Q-value per action).

        With 'valid' 3x3 convs and 2x2 pools, a 10x10 input shrinks
        10 -> 8 -> 4 -> 2 -> 1. Spatial dims below 10 make the last pool fail.
        """
        model = Sequential([
            # ReLU added. Without it, the first conv was purely linear.
            Conv2D(64, (3, 3), activation='relu', input_shape=self.state_size),
            MaxPooling2D((2, 2)),
            Conv2D(64, (3, 3), activation='relu'),
            MaxPooling2D((2, 2)),
            Flatten(),
            Dense(64, activation='relu'),
            Dense(self.action_size, activation='linear')
        ])
        model.compile(loss='mse', optimizer=Adam(learning_rate=0.001))
        return model

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.set_weights(self.model.get_weights())

    def act(self, state):
        """Pick an epsilon-greedy action for one state batched as ``(1, *state_size)``."""
        if np.random.rand() <= self.epsilon:
            return np.random.randint(0, self.action_size)
        q_values = self.model.predict(state, verbose=0)
        return int(np.argmax(q_values[0]))

    def remember(self, state, action, reward, next_state, done):
        """Store one transition; states are expected to carry a leading batch axis of 1."""
        self.memory.append((state, action, reward, next_state, done))
        self.memory_len += 1

    def replay(self):
        """Do one gradient step on a random minibatch, using Bellman targets from the target net.

        Does nothing until the buffer holds at least ``min_replay_size``
        transitions (and never fewer than ``batch_size``). Targets are
        ``r`` for terminal transitions and ``r + gamma * max_a' Q_target(s', a')``
        otherwise. The whole minibatch is predicted and fitted in one batched
        call each, rather than one predict/fit per sample.
        """
        if len(self.memory) < max(self.min_replay_size, self.batch_size):
            return

        minibatch = random.sample(self.memory, self.batch_size)
        states = np.concatenate([transition[0] for transition in minibatch], axis=0)
        actions = np.array([transition[1] for transition in minibatch], dtype=int)
        rewards = np.array([transition[2] for transition in minibatch], dtype=float)
        next_states = np.concatenate([transition[3] for transition in minibatch], axis=0)
        dones = np.array([transition[4] for transition in minibatch], dtype=bool)

        next_q = self.target_model.predict(next_states, verbose=0)
        bellman = np.where(dones, rewards, rewards + self.gamma * np.max(next_q, axis=1))

        # Only the Q-value of the action actually taken gets a new target. The
        # others keep the model's own prediction, so they contribute zero error.
        q_targets = np.array(self.model.predict(states, verbose=0), dtype=float)
        q_targets[np.arange(len(minibatch)), actions] = bellman
        self.model.fit(states, q_targets, epochs=1, batch_size=len(minibatch), verbose=0)

In [None]:
# Shared board and agent used by the training/testing cells below.
env = SnakeGameEnv(grid_size=10)
# SnakeGameEnv.step distinguishes 3 relative actions: 0 = straight,
# 1 = turn right, 2 = turn left. A 4th output would just duplicate "straight".
agent = DQNAgent(state_size=(env.grid_size, env.grid_size, 3), action_size=3)

In [None]:
def train_model(num_episodes=10000, max_steps=1000):
    """Train the global ``agent`` on the global ``env`` with epsilon-greedy DQN.

    Args:
        num_episodes: number of episodes to play.
        max_steps: cap on steps per episode, or None for no cap. A policy that
            only turns in circles never collides, so without a cap one episode
            (and the whole training run) can loop forever. Truncated episodes
            keep ``done=False`` on their last transition, so it still bootstraps.

    Returns:
        List with the total reward of each episode.
    """
    rewards_collected = []
    avg_window = agent.target_model_update_frequency
    for episode in range(1, num_episodes + 1):
        state = env.reset()
        state = state.reshape(-1, *state.shape) / 255
        episode_reward = 0
        step_count = 0
        print(f"Episode {episode}", end=" ")

        while not env.done and (max_steps is None or step_count < max_steps):
            action = agent.act(state)
            next_state, reward, done = env.step(action)
            next_state = next_state.reshape(-1, *next_state.shape) / 255

            agent.remember(state, action, reward, next_state, done)

            state = next_state
            episode_reward += reward
            step_count += 1

            # Learn every `replay_frequency` environment steps rather than only
            # once per episode (replay() is a no-op until the buffer is warm).
            if step_count % agent.replay_frequency == 0:
                agent.replay()

        rewards_collected.append(episode_reward)
        print(f"Reward: {episode_reward} Score: {env.score}")

        # Sync the target network and report the average over the same window.
        if episode % agent.target_model_update_frequency == 0:
            recent = rewards_collected[-avg_window:]
            print(f"Episode: {episode}  Avg. Reward: {sum(recent) / len(recent)}")
            agent.update_target_model()

        if agent.epsilon > agent.epsilon_min:
            agent.epsilon *= agent.epsilon_decay

        # NOTE(review): presumably a workaround for memory growth from repeated
        # predict() calls; confirm it is still needed.
        tf.keras.backend.clear_session()

    return rewards_collected

In [None]:
def test_model(num_episodes=10, max_steps=1000, delay_ms=50):
    """Play greedy (epsilon = 0) episodes with the global ``agent`` and show them.

    Args:
        num_episodes: number of episodes to play.
        max_steps: cap on steps per episode, or None for no cap. A
            deterministic policy can circle forever without colliding.
        delay_ms: milliseconds to wait per frame. This ``cv2.waitKey`` call is
            what makes the OpenCV window repaint. Press ESC to stop early.

    The agent's previous epsilon is restored and the windows are closed when done.
    """
    saved_epsilon = agent.epsilon
    agent.epsilon = 0
    try:
        for episode in range(1, num_episodes + 1):
            state = env.reset()
            state = state.reshape(-1, *state.shape) / 255
            episode_reward = 0
            step_count = 0

            while not env.done and (max_steps is None or step_count < max_steps):
                # Choose actions greedily (exploit) based on the learned Q-values.
                action = agent.act(state)
                next_state, reward, done = env.step(action)
                state = next_state.reshape(-1, *next_state.shape) / 255
                episode_reward += reward
                step_count += 1

                env.render()
                if cv2.waitKey(delay_ms) & 0xFF == 27:  # ESC aborts the test run
                    return

            print(f"Episode {episode}  Reward: {episode_reward}  Score: {env.score}  Steps: {step_count}")
    finally:
        agent.epsilon = saved_epsilon
        cv2.destroyAllWindows()


In [None]:
# agent.model = load_model("snake_model-1694437083.h5")
# test_model()

In [None]:
train_model(num_episodes=5000)
# An integer timestamp keeps the name in the snake_model-<seconds>.h5 form used
# when loading (a float would add a second '.' to the filename).
agent.model.save(f"snake_model-{int(time())}.h5")

In [None]:
# Manual play: '1' turns right, '2' turns left, ESC quits. Any other key (or no
# key within 100 ms) keeps the snake going straight. A new game starts after
# each crash.
env = SnakeGameEnv(grid_size=10)
state = env.reset()
KEY_TO_ACTION = {ord('1'): 1, ord('2'): 2}
try:
    while True:
        env.render()
        key = cv2.waitKey(100) & 0xFF  # low byte only; "no key" (-1) becomes 255
        if key == 27:
            break  # ESC to exit
        action = KEY_TO_ACTION.get(key, 0)
        _, _, done = env.step(action)
        if done:
            print(f"Game over - score: {env.score}")
            state = env.reset()
finally:
    cv2.destroyAllWindows()

### Additional state functions from version 0

In [None]:
def _is_surrounded_by_body(self, direction):
    start_row, start_col = self.snake[0][0] + direction[0], self.snake[0][1] + direction[1]
    if not 0 <= start_row < self.grid_size or not 0 <= start_col < self.grid_size or self.grid[start_row, start_col] == 1: 
        return False

    moves = [(0, 1), (0, -1), (1, 0), (-1, 0)]
    visited = [[False] * self.grid_size for _ in range(self.grid_size)]        
    queue = deque([(start_row, start_col)])
    
    while queue:
        row, col = queue.popleft()
        visited[row][col] = True
        
        # If we reach boundary then it is not close loop
        if row == 0 or row == self.grid_size - 1 or col == 0 or col == self.grid_size - 1:
            return False 
        
        for dr, dc in moves:
            r, c = row + dr, col + dc
            if 0 <= r < self.grid_size and 0 <= c < self.grid_size and not visited[r][c] and self.grid[r, c] == 0:
                queue.append((r, c))
    
    return True

def _open_area(self, direction):
    start_row, start_col = self.snake[0][0] + direction[0], self.snake[0][1] + direction[1]
    if not 0 <= start_row < self.grid_size or not 0 <= start_col < self.grid_size or self.grid[start_row, start_col] == 1: 
        return 100
    visited = set()
    queue = deque([(start_row, start_col)])

    area = 0
    while queue:
        r, c = queue.popleft()

        if (r, c) not in visited and 0 <= r < self.grid_size and 0 <= c < self.grid_size and self.grid[r, c] != 1:
            visited.add((r, c))
            area += 1

            # Check adjacent cells
            directions = [(0, 1), (0, -1), (1, 0), (-1, 0)]
            for dr, dc in directions:
                new_r, new_c = r + dr, c + dc
                queue.append((new_r, new_c))

    return area

In [None]:
# cur_dir = self.direction
# idx = self.clock_wise.index(cur_dir)
# cur_dir_l = self.clock_wise[(idx - 1) % 4]
# cur_dir_r = self.clock_wise[(idx + 1) % 4]

# area_l = self._open_area(cur_dir_l)
# area_r = self._open_area(cur_dir_r)
# if area_l == 100: area_r = 100
# elif area_r == 100: area_l = 100

# Check if next move head will be surrounded by snake body
# self._is_surrounded_by_body(cur_dir),
# self._is_surrounded_by_body(cur_dir_l),
# self._is_surrounded_by_body(cur_dir_r),

# area_l < area_r,
# area_r < area_l,