In [103]:
import numpy as np
import torch
import pygame

class SnakeGame():
    """Snake played on a grid of rooms separated by walls with random doors.

    The board is a 2-D numpy array indexed as ``board[row, column]`` in which
    ``1`` marks a wall and ``0`` marks a free cell.  There are ``n_rooms``
    rooms along each axis; walls sit on every multiple of the room size, so
    the board has ``room_size[i] * n_rooms + 1`` cells along axis ``i``.

    The snake only moves when a direction has been requested: ``move_snake``
    consumes ``self.direction`` and resets it to ``None`` after every step.

    Parameters
    ----------
    n_rooms : int
        Number of rooms along each axis.
    room_size : tuple of int
        ``(rows, columns)`` distance between two consecutive walls.
    """

    # Direction name -> the direction that would reverse it.
    _OPPOSITE = {"left": "right", "right": "left", "up": "down", "down": "up"}
    # Direction name -> (row, column) offset applied to the head.
    _STEPS = {"right": (0, 1), "left": (0, -1), "up": (-1, 0), "down": (1, 0)}

    def __init__(self, n_rooms=2, room_size=(10,10)):
        self.n_rooms = n_rooms
        self.room_size = room_size
        self.board_size = tuple(s*n_rooms+1 for s in room_size)
        self.reset()

    def reset(self):
        """Initialize the game board and snake position"""
        self.game_over = False
        self.score = 0

        self.spawn_board()
        self.spawn_snake()
        self.spawn_food()

    def spawn_board(self):
        """Build a fresh board: outer edges, interior walls and one door per room side."""
        room_rows, room_cols = self.room_size
        self.board = np.zeros(self.board_size)

        # Draw edges
        self.board[self.n_rooms*room_rows, :] = 1
        self.board[:, self.n_rooms*room_cols] = 1
        self.board[0, :] = 1
        self.board[:, 0] = 1

        # Draw walls
        for wall_room_idx in range(1, self.n_rooms):
            # Horizontal wall, with one door per room it borders.  Door
            # offsets are drawn from [1, room_cols), so a door never lands on
            # a vertical wall column.
            wall_row = wall_room_idx*room_rows
            self.board[wall_row, :] = 1
            for door_room_idx in range(self.n_rooms):
                door_col = door_room_idx*room_cols + np.random.randint(1, room_cols)
                self.board[wall_row, door_col] = 0

            # Vertical wall, with one door per room it borders.
            wall_col = wall_room_idx*room_cols
            self.board[:, wall_col] = 1
            for door_room_idx in range(self.n_rooms):
                door_row = door_room_idx*room_rows + np.random.randint(1, room_rows)
                self.board[door_row, wall_col] = 0

    def is_wall(self, position):
        """Return whether the board cell at ``position`` (row, column) is a wall."""
        return self.board[position] == 1

    def _random_position(self):
        """Return a uniformly random (row, column) cell, walls included."""
        return tuple(np.random.randint(0, bound) for bound in self.board_size)

    def spawn_snake(self):
        """Place a one-segment snake on a random non-wall cell and clear its heading."""
        start_position = self._random_position()
        while self.is_wall(start_position):
            start_position = self._random_position()
        self.snake_position = [start_position]
        self.direction = None
        # Direction of the most recent step; used to reject reversals.
        self.last_direction = None

    def spawn_food(self):
        """Place the food on a random cell that is neither a wall nor the snake."""
        self.food_position = self._random_position()
        while self.is_wall(self.food_position) or self.food_position in self.snake_position:
            self.food_position = self._random_position()

    def get_board_state(self):
        """Return a tensor copy of the board with the snake (2) and food (3) drawn in."""
        # torch.tensor copies the array, so self.board itself is not modified.
        state = torch.tensor(self.board)
        for row, col in self.snake_position:
            state[row, col] = 2
        food_row, food_col = self.food_position
        state[food_row, food_col] = 3
        return state

    def set_direction(self, direction):
        """Request the direction of the next step.

        A request that would reverse the previous step is ignored while the
        snake has a body, because the head would move straight into its own
        neck.  A one-segment snake may reverse freely.

        Returns True if the request was accepted, False if it was ignored.
        """
        reverses_last_step = (
            self.last_direction is not None
            and self._OPPOSITE.get(direction) == self.last_direction
        )
        if reverses_last_step and len(self.snake_position) > 1:
            return False
        self.direction = direction
        return True

    def move_snake(self):
        """Move the snake one cell in the pending direction, if there is one."""
        if self.direction is None:
            return

        head_row, head_col = self.snake_position[0]
        # Unrecognised direction names leave the head where it is, as before.
        d_row, d_col = self._STEPS.get(self.direction, (0, 0))
        self.snake_position.insert(0, (head_row + d_row, head_col + d_col))

        # Keep the tail only when the new head landed on food (the snake grows).
        if not self.is_food_eaten():
            self.snake_position.pop()

        # Remember the heading for the reversal check, then consume the request.
        self.last_direction = self.direction
        self.direction = None

    def is_game_over(self):
        """Return True if the head is out of bounds, on the snake's body, or on a wall."""
        head_position = self.snake_position[0]

        # Bounds are checked first so the board lookup below cannot wrap
        # around on a negative index or raise on an index past the end.
        for position, bound in zip(head_position, self.board_size):
            if position < 0 or position >= bound:
                return True

        # Check if the snake has collided with itself
        if head_position in self.snake_position[1:]:
            return True

        # Check if the snake has collided with walls
        if self.board[head_position] == 1:
            return True

        return False

    def is_food_eaten(self):
        """Return whether the snake's head is on the food cell."""
        return self.snake_position[0] == self.food_position

    def render(self, magnify = 20):
        """Render the current state of the game.

        ``magnify`` is the edge length, in pixels, of one board cell.
        """
        window_size = (self.board_size[1]*magnify, self.board_size[0]*magnify)

        # Reuse the existing window instead of recreating it on every frame.
        screen = pygame.display.get_surface()
        if screen is None or screen.get_size() != window_size:
            screen = pygame.display.set_mode(window_size)

        # Fill the background with white
        screen.fill((255, 255, 255))
        # Draw walls; screen x is the board column and screen y is the row.
        for row, col in np.argwhere(self.board == 1).tolist():
            pygame.draw.rect(screen, (0, 0, 0), (col*magnify, row*magnify, magnify, magnify))
        # Draw the snake
        for row, col in self.snake_position:
            pygame.draw.rect(screen, (0, 255, 0), (col*magnify, row*magnify, magnify, magnify))
        # Draw the food
        food_row, food_col = self.food_position
        pygame.draw.rect(screen, (255, 0, 0), (food_col*magnify, food_row*magnify, magnify, magnify))
        # Update the display
        pygame.display.flip()

    def handle_game_over(self):
        """Handle the game over condition"""
        print("Game over! Your score is: ", self.score)
        self.reset()

    def handle_inputs(self):
        """Handle user inputs to control snake movement"""
        key_to_direction = {
            pygame.K_LEFT: "left",
            pygame.K_RIGHT: "right",
            pygame.K_UP: "up",
            pygame.K_DOWN: "down",
        }
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                self.game_over = True
            elif event.type == pygame.KEYDOWN:
                direction = key_to_direction.get(event.key)
                if direction is not None:
                    self.set_direction(direction)

    def run(self):
        """Play one interactive game until the snake dies or the window is closed."""
        pygame.init()
        try:
            pygame.display.set_caption("Snake Game")
            clock = pygame.time.Clock()
            self.reset()
            # self.game_over is set by handle_inputs when the window is closed.
            while not (self.game_over or self.is_game_over()):
                self.render()
                self.handle_inputs()
                self.move_snake()
                if self.is_food_eaten():
                    self.score += 1
                    self.spawn_food()
                clock.tick(10)
            self.handle_game_over()
        finally:
            # Always close the window, even if an error escaped the loop.
            pygame.quit()

In [106]:
# Play one interactive session: three rooms along each axis, with walls spaced
# 7 rows and 9 columns apart. run() blocks until the game ends.
env = SnakeGame(room_size=(7, 9), n_rooms=3)
env.run()

Game over! Your score is:  1
