In [1]:
import numpy as np
import onnxruntime as ort

In [2]:
HEAD = 4
BODY = 3
FRUIT = 2
EMPTY = 1
WALL = 0

In [3]:
class SnakeEnv:
    HEAD = 4
    BODY = 3
    FRUIT = 2
    EMPTY = 1
    WALL = 0

    UP = 0
    RIGHT = 1
    DOWN = 2
    LEFT = 3
    NONE = 4

    def __init__(self, n_boards, board_size):
        self.board_size = board_size
        self.n_boards = n_boards

        # Initialize the boards and the bodies
        self.boards, self.bodies = self.init_board(n_boards, board_size)

    def init_board(self, n_boards, board_size):
        boards = np.ones((n_boards, board_size, board_size)) * self.EMPTY
        bodies = [[] for _ in range(n_boards)]
        
        # Add walls
        boards[:, [0, -1], :] = self.WALL
        boards[:, :, [0, -1]] = self.WALL
        
        for i in range(n_boards):
            # Add head
            available = np.argwhere(boards[i] == self.EMPTY)
            ind = available[np.random.choice(len(available))]
            boards[i, ind[0], ind[1]] = self.HEAD
            
            # Add fruit
            available = np.argwhere(boards[i] == self.EMPTY)
            ind = available[np.random.choice(len(available))]
            boards[i, ind[0], ind[1]] = self.FRUIT
        
        return boards, bodies

    def move(self, actions):
        n_boards = len(self.boards)
        
        # Find heads in the boards
        heads = np.argwhere(self.boards == self.HEAD)
        actions = np.array(actions)
        
        # Calculate action offsets
        dx = np.zeros(len(actions))
        dx[np.where(actions == self.UP)[0]] = 1
        dx[np.where(actions == self.DOWN)[0]] = -1
        dy = np.zeros(len(actions))
        dy[np.where(actions == self.RIGHT)[0]] = 1
        dy[np.where(actions == self.LEFT)[0]] = -1
        offset = np.hstack((np.zeros_like(actions)[:, None], dx[:, None], dy[:, None]))
        
        # Calculate new head positions
        new_heads = (heads + offset).astype(int)
        
        # Check for collisions with walls
        walls = np.argwhere(self.boards == self.WALL)
        walls_code = np.sum(walls * ((np.max(walls) + 1) ** np.arange(walls.shape[1])), axis=-1)
        new_heads_code = np.sum(new_heads * ((np.max(walls) + 1) ** np.arange(walls.shape[1])), axis=-1)
        hit_wall = np.isin(new_heads_code, walls_code)
        new_heads[hit_wall] = heads[hit_wall]
        
        # Fruits per board
        fruits = np.argwhere(self.boards == self.FRUIT)
        fruits_eaten_bool = np.any(np.all(new_heads[:, None, 1:] == fruits[:, 1:], axis=-1), axis=1)
        
        # Update bodies and boards
        for i in range(n_boards):
            self.bodies[i].insert(0, heads[i, 1:])
            self.boards[i][np.where(self.boards[i] == self.BODY)] = self.EMPTY
            
            if not fruits_eaten_bool[i]:
                self.bodies[i].pop()
            
            if self.bodies[i]:
                body = np.array(self.bodies[i])
                self.boards[i][body[:, 0], body[:, 1]] = self.BODY
            
            self.boards[i][new_heads[i, 1], new_heads[i, 2]] = self.HEAD
            
            if fruits_eaten_bool[i]:
                available = np.argwhere(self.boards[i] == self.EMPTY)
                if len(available) == 0:
                    self.boards[i], self.bodies[i] = self.init_board(1, self.board_size)
                else:
                    ind = available[np.random.choice(len(available))]
                    self.boards[i][ind[0], ind[1]] = self.FRUIT

    def get_boards(self):
        return self.boards

    def to_state(self):
        num_classes = 5  # [WALL, EMPTY, FRUIT, BODY, HEAD]
        # One-hot encode the board representation
        state = np.eye(num_classes)[self.boards.astype(int)]
        state = state[..., 1:]  # Remove the first category (WALL)
        return state

In [10]:
def get_action(ort_session, state):
    input_name = ort_session.get_inputs()[0].name
    output_name = ort_session.get_outputs()[0].name
    input = state.astype(np.float32)
    outputs = ort_session.run([output_name], {input_name: input})
    return np.argmax(outputs[0])

In [11]:
onnx_model_path = "hybrid_model.onnx"
session = ort.InferenceSession(onnx_model_path)
env = SnakeEnv(1, 7)
for _ in range(10):
    state = env.to_state()
    action = get_action(session, state)
    env.move([action])
    print(env.get_boards()[0])


[[0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 1. 1. 1. 1. 0.]
 [0. 1. 1. 1. 1. 1. 0.]
 [0. 1. 1. 1. 1. 1. 0.]
 [0. 1. 1. 2. 1. 1. 0.]
 [0. 1. 1. 3. 4. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 1. 1. 1. 1. 0.]
 [0. 1. 1. 1. 1. 1. 0.]
 [0. 1. 1. 1. 1. 1. 0.]
 [0. 1. 1. 2. 4. 1. 0.]
 [0. 1. 1. 1. 3. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 1. 1. 1. 1. 0.]
 [0. 1. 1. 1. 1. 1. 0.]
 [0. 1. 2. 1. 1. 1. 0.]
 [0. 1. 1. 4. 3. 1. 0.]
 [0. 1. 1. 1. 3. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 1. 1. 1. 1. 0.]
 [0. 1. 1. 1. 1. 1. 0.]
 [0. 1. 2. 4. 1. 1. 0.]
 [0. 1. 1. 3. 3. 1. 0.]
 [0. 1. 1. 1. 1. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 2. 1. 1. 1. 0.]
 [0. 1. 1. 1. 1. 1. 0.]
 [0. 1. 4. 3. 1. 1. 0.]
 [0. 1. 1. 3. 3. 1. 0.]
 [0. 1. 1. 1. 1. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0.]]
[[0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 2. 1. 1. 1. 0.]
 [0. 1. 4. 1. 1. 1. 0.]
 [0. 1. 3. 3. 1. 1. 0.]
 [0. 1. 1. 3. 1. 1. 0.]
 [0. 1. 1. 1. 1. 1. 0.]
 [0. 0. 0. 

In [1]:
%load_ext autoreload
%autoreload 2

In [12]:
from utils.snake import SnakeEnv, HybridAgent

env = SnakeEnv(1, 7)
hyb = HybridAgent(env.boards, "models/hybrid_model.onnx")
for _ in range(1000):
    action = hyb.get_actions(env.boards)
    env.move(action)
    print(action)

[[2]]
[[2]]
[[2]]
[[3]]
[[0]]
[[0]]
[[0]]
[[1]]
[[2]]
[[2]]
[[3]]
[[3]]
[[2]]
[[3]]
[[1]]
[[3]]
[[0]]
[[3]]
[[2]]
[[1]]
[[1]]
[[1]]
[[0]]
[[3]]
[[0]]
[[3]]
[[0]]
[[1]]
[[1]]
[[2]]
[[2]]
[[3]]
[[1]]
[[3]]
[[3]]
[[3]]
[[0]]
[[0]]
[[1]]
[[2]]
[[1]]
[[2]]
[[1]]
[[2]]
[[3]]
[[3]]
[[0]]
[[0]]
[[1]]
[[0]]
[[3]]
[[4]]
[[3]]
[[2]]
[[2]]
[[2]]
[[1]]
[[3]]
[[0]]
[[1]]
[[0]]
[[1]]
[[0]]
[[1]]
[[2]]
[[3]]
[[2]]
[[3]]
[[2]]
[[3]]
[[0]]
[[0]]
[[1]]
[[1]]
[[0]]
[[1]]
[[1]]
[[0]]
[[3]]
[[3]]
[[3]]
[[3]]
[[2]]
[[4]]
[[0]]
[[1]]
[[2]]
[[2]]
[[2]]
[[3]]
[[0]]
[[0]]
[[1]]
[[1]]
[[1]]
[[2]]
[[4]]
[[1]]
[[2]]
[[4]]
[[2]]
[[3]]
[[0]]
[[3]]
[[1]]
[[3]]
[[3]]
[[0]]
[[3]]
[[0]]
[[0]]
[[1]]
[[2]]
[[2]]
[[2]]
[[3]]
[[0]]
[[0]]
[[1]]
[[1]]
[[0]]
[[2]]
[[2]]
[[2]]
[[1]]
[[1]]
[[0]]
[[3]]
[[3]]
[[3]]
[[2]]
[[3]]
[[1]]
[[0]]
[[0]]
[[0]]
[[3]]
[[2]]
[[1]]
[[1]]
[[1]]
[[2]]
[[3]]
[[0]]
[[3]]
[[2]]
[[2]]
[[2]]
[[1]]
[[0]]
[[0]]
[[0]]
[[0]]
[[3]]
[[1]]
[[2]]
[[2]]
[[2]]
[[3]]
[[0]]
[[0]]
[[1]]
[[2]]
[[2]]
[[2]]
[[3]]
[[0]