In [16]:
config = {
    "n_player": 1,
    "board_width": 8,
    "board_height": 6,
    "n_beans": 5,
    "max_step": 50,
}

test = mySnake(config)
state = test.reset()

for i in range(100):
    state, _ = test.step(np.random.choice([4, 1, 2, 3], 2))
    plt.imshow(state.reshape(6, 8))
    plt.show()
    display.display(plt.gcf())
    display.clear_output(wait=True)

<Figure size 640x480 with 0 Axes>

In [51]:
import numpy as np
import matplotlib.pyplot as plt


class Snake():
    """
    Some information about the game:

    1. I/O
    input of the game: [action]
    common return [state, done, reward]

    2. update/step
    update is divided into two parts, doing action and after action. many flags are set for after actions.

    3. some constrains
    supported number of snake : 2
    minimum
    """

    def __init__(self, config):

        # configuration
        self.n_player = config["n_player"]
        self.board_width = config["board_width"]
        self.board_height = config["board_height"]

        self.n_beans = config["n_beans"]
        self.max_step = config["max_step"]

        self.check_config()

        # storage
        self.state = np.zeros(self.board_width * self.board_height, dtype=int)  # stored as a matrix, as it is fixed
        self.__available_state = set(np.arange(self.state.size))  # actually it is an index of the state
        self.snake = [[] for _ in range(self.n_player)]  # stored as a list, as it vary in length in different condition
        self.food = []  # same as above
        self.last_snake = [[] for _ in range(self.n_player)]

        # print
        print("Now please use env.reset() to reset the game")

        # flags
        self.flag_init = False
        self.flag_die = [False for _ in range(self.n_player)]
        self.flag_food = [False for _ in range(self.n_beans)]  # the size of the food is fixed
        self.done = False

        # counter
        self.counter = 0

    def reset(self):
        """
        reset the game

        In specific, two steps
            1. place snakes
            2. place food

        :return: only [state]
        """
        # reset the all storage
        self.state = np.zeros(self.board_height * self.board_width)
        self.__available_state = set(np.arange(self.state.size))
        self.snake = [[] for _ in range(self.n_player)]
        self.food = []
        self.reward = [[] for _ in range(self.n_player)]

        # reset the snake
        for i in range(self.n_player):
            # the initial length of the snake is 3
            self.snake[i] = []
            self.snake[i].append([0, 2 * i])
            self.snake[i].append([0, 2 * i + 1])
            self.snake[i].append([1, 2 * i + 1])  # the initial position of the snake

            index = [
                0 + 2 * i,
                0 + 2 * i + 1,
                1 * self.board_width + 2 * i + 1,
            ]

            for idx in index:
                self.__available_state.remove(idx)
                self.state[idx] = i + 1

        # reset the food
        locations_index = np.random.choice(list(self.__available_state), self.n_beans, replace=False)
        for index in locations_index:
            self.food.append([index // self.board_width, index % self.board_width])
            self.__available_state.remove(index)
            self.state[index] = -1  # food's state is -1

        # reset the done
        self.done = False
        self.counter = 0

        return self.state

    def step(self, actions):
        """
        0. verify the action
        1. update the snake 1.1 update the body 1.2 update the head
        2. flag
        :param action: has the shape of [action1,action2,....]
        :return:
        """
        # TODO: need to finish it
        for i in range(self.n_player):
            self.last_snake[i] = self.snake[i].copy()

        tail = [[] for _ in range(self.n_player)]
        # step0: verify the action's correctness
        for single_action in actions:
            assert 1 <= single_action <= 4, 'the action {} is not in the set of [1,2,3,4]'.format(single_action)

        # step1: update the snake
        # first update the body and then update the head ########

        for i, single_snake in enumerate(self.snake):
            # now single_snake is a list and i+1 is the id of it
            # action means: 1 go left, 2 go up, 3 go right, 4 go down
            # update the body

            # # restore state
            # for single_body in single_snake:
            #     idx_body_in_state = single_body[0] * self.board_width + single_body[1]
            #     self.state[idx_body_in_state] = 0
            #     self.__available_state.add(idx_body_in_state)

            tail[i] = single_snake.pop(-1)  # remove the last one

            # update the head
            head = single_snake[0].copy()
            if actions[i] == 1:  # go left
                head[1] -= 1
            elif actions[i] == 2:  # go up
                head[0] -= 1
            elif actions[i] == 3:  # go right
                head[1] += 1
            elif actions[i] == 4:  # go  down
                head[0] += 1

            boundary_check_head = self.boundaryCheck(row=head[0], col=head[1])
            head = [boundary_check_head // self.board_width, boundary_check_head % self.board_width]
            single_snake.insert(0, head)

            self.snake[i] = single_snake

        # first check what he has eaten or ate nothing
        #  what we need to check to get the reward? 1. if the snake ate its body 2. if collide with each other  3. if it ate any food.
        if self.n_player == 2:  # assume 2 as the maximum of players
            if self.snake[0][0] == self.snake[1][0]:
                self.flag_die = [True, True]

        for i, single_snake in enumerate(self.snake):
            idx_head_in_state = single_snake[0][0] * self.board_width + single_snake[0][1]  # head's idx in state

            if self.state[idx_head_in_state] == 0:  # ate nothing
                self.reward[i] = 0

            elif self.state[idx_head_in_state] == -1:  # ate food
                # reward set to 1, you need to create a new food
                self.reward[i] = 1
                self.snake[i].append(tail[i])  # add the tail back
                # find the idx of the food
                idx_food_in_state = idx_head_in_state
                idx_food_in_map = [idx_head_in_state // self.board_width, idx_head_in_state % self.board_width]
                idx_food_change_in_food_list = self.food.index(idx_food_in_map)
                self.flag_food[
                    idx_food_change_in_food_list] = True  # set the flag of the lost food to be true, means it need to be replaced

            else:  # ate body
                self.flag_die[i] = True
                self.reward[i] = 3 - len(single_snake)

        # handle flag
        self.handleFlag()
        self.flag_die = [False for _ in range(self.n_player)]
        self.flag_food = [False for _ in range(self.n_beans)]
        # construct the new state
        self.state = np.zeros(self.board_height * self.board_width)
        self.__available_state = set(range(self.board_height * self.board_width))
        for i, single_snake in enumerate(self.snake):
            for single_body in single_snake:
                idx_body_in_state = single_body[0] * self.board_width + single_body[1]
                self.state[idx_body_in_state] = i + 1
        for single_food in self.food:
            idx_food_in_state = single_food[0] * self.board_width + single_food[1]
            self.state[idx_food_in_state] = -1
            self.__available_state.remove(idx_food_in_state)

        return self.state, self.reward

        # in the edn handle flag

    def check_config(self):
        """
        TODO: need to finish it, do it in the end
        check the configuration
        :return: None
        """
        assert 0 < self.n_player < 3, "only support 2-player game"
        assert self.board_width >= 8, "board width should be larger than 8"
        assert self.board_height >= 6, "board height should be larger than 6"
        assert self.n_beans >= 5, "number of beans should be larger than 5"
        assert self.max_step >= 50, "max step should be larger than 50"

    def handleFlag(self):
        """

        TODO: finish it
        update the state according to the flag
        :return:
        """
        # what we need to generate a new state?
        # 1. new self.snake
        # 2. new self.food

        # 1. new self.snake
        for i, single_flag_die in enumerate(self.flag_die):
            if single_flag_die == False:
                continue
            # now the snake is dead
            # steps to do: 1.remove the snake from the state 2. respawn it in the state
            # 1. remove the snake from the state
            idx_snake_in_state = []
            for block in self.last_snake[i]:
                idx_snake_in_state.append(block[0] * self.board_width + block[1])
            assert np.all(self.state[idx_snake_in_state] == i + 1), "the snake is not in the state"
            for idx in idx_snake_in_state:
                self.state[idx] = 0

        for i, single_flag_die in enumerate(self.flag_die):
            if single_flag_die == False:
                continue
            # 2. respawn it in the state
            # first find the available state,just find a place to respawn it
            for start_place in range(self.state.size):
                idx1 = self.boundaryCheck(start_place)
                idx2 = self.boundaryCheck(start_place + 1)
                idx3 = self.boundaryCheck(start_place + 1 + self.board_width)
                # print(self.state.reshape(self.board_height, self.board_width))
                if self.state[idx1] <= 0:
                    if self.state[idx2] <= 0:
                        if self.state[idx3] <= 0:

                            # this is the place we spawn the snake
                            # first report food
                            for idx in [idx1, idx2, idx3]:
                                if self.state[idx] == -1:
                                    self.reportFood(idx)

                            # then spawn the snake
                            idx_in_state = [idx1, idx2, idx3]
                            idx_in_matrix = [[idx_in_state[0] // self.board_width, idx_in_state[0] % self.board_width],
                                             [idx_in_state[1] // self.board_width, idx_in_state[1] % self.board_width],
                                             [idx_in_state[2] // self.board_width, idx_in_state[2] % self.board_width]]

                            self.snake[i] = []
                            for single_idx_in_matrix in idx_in_matrix:
                                self.snake[i].append(single_idx_in_matrix)
                            break

        # 2. new self.food
        # find available state
        temp_available = set(np.arange(self.state.size))
        for single_snake in self.snake:
            for single_block in single_snake:
                temp_available.remove(single_block[0] * self.board_width + single_block[1])
        for single_food in self.food:  # no matter it is eaten or not, we can to remove it from the available state,
            try:
                temp_available.remove(single_food[0] * self.board_width + single_food[1])
            except:
                pass  # if eaten, then the block is occupied by the snake, if not, then it is occupied by the food

        num_resapwn_food = np.sum(self.flag_food)
        locations_index = np.random.choice(list(temp_available), num_resapwn_food, replace=False)
        local_counter = 0
        for i in range(len(self.food)):
            if self.flag_food[i] == True:
                this_index = locations_index[local_counter]
                self.food[i] = [this_index // self.board_width, this_index % self.board_width]
                self.__available_state.add(this_index)
                self.state[this_index] = -1  # food's state is -1
                local_counter += 1

    # support functions
    def reportFood(self, idx):
        # have detected a food in this place. 1. report it to flag, set it to snake's body
        idx_food_in_map = [idx // self.board_width, idx % self.board_width]
        idx_food_change_in_food_list = self.food.index(idx_food_in_map)

        self.flag_food[
            idx_food_change_in_food_list] = True  # set the flag of the lost food to be true, means it need to be replaced

    def spwanSnake(self, i):
        '''
        # TODO: finish it
        :param i:
        :return:
        '''

    def boundaryCheck(self, idx=None, row=None, col=None):
        if idx == None and row == None and col == None:
            raise ValueError("You must specify one of the three parameters")
        if idx != None:
            row = idx // self.board_width
            col = idx % self.board_width
        if row >= 6:
            row = abs(row % 6)
        if col >= 8:
            col = abs(col % 8)

        if col < 0:
            col = 8 + col
        if row < 0:
            row = 6 + row
        new_idx = row * self.board_width + col
        if new_idx >= self.state.size:
            raise ValueError("the new idx is larger than the state size")
        return new_idx


### Provide your code here
import torch
from maai_cwork.examples.common import utils
from maai_cwork.env.chooseenv import make
from maai_cwork.env import snakes
from maai_cwork.run_utils import get_players_and_action_space_list, run_game
from torch import nn
from maai_cwork.env.chooseenv import make
import numpy as np
import matplotlib.pyplot as plt
from IPython import display
import torchviz


class SnakeAgent():
    def __init__(self, env):
        # set hyperparameters
        self.max_episodes = 200
        self.max_actions = 50
        self.gamma = 0.9
        self.exploration_rate = 0.5
        self.exploration_decay = self.exploration_rate / self.max_episodes
        # get envirionment
        self.env = env
        self.last_state = torch.tensor(self.env.reset(), dtype=torch.double)
        self.__nn()

    def __nn(self):
        # define the neural network model for snake 1
        self.net1 = nn.Sequential(nn.Linear(6 * 8, 256),
                                  nn.Linear(256, 256),
                                  nn.Linear(256, 128),
                                  nn.Linear(128, 64),
                                  nn.Linear(64, 4))

        self.net1.apply(self.__init_weights)
        self.loss1 = nn.MSELoss(reduction='sum')
        self.optimizer1 = torch.optim.Adam(self.net1.parameters(), lr=0.001)
        self.net1.to(torch.float64)
        # self.net1.to('cuda')

    def __init_weights(self, m):
        # initialize the weights of the neural network
        if type(m) == nn.Linear:
            nn.init.normal_(m.weight, std=0.01)

    def train(self):
        # get hyper parameters
        max_episodes = self.max_episodes
        max_actions = self.max_actions
        gamma = self.gamma
        exploration_rate = self.exploration_rate
        exploration_decay = self.exploration_decay
        env = self.env

        for i in range(max_episodes):
            for j in range(max_actions):
                # get the action
                if np.random.rand() < exploration_rate:
                    action = np.random.choice([1, 2, 3, 4], 1)
                else:
                    self.last_state = self.last_state.to(torch.double)
                    action = torch.argmax(self.net1(self.last_state)).item() + 1
                # get the next state
                state, reward = env.step([action])
                state = torch.tensor(state, dtype=torch.double)
                # update the neural network
                self.optimizer1.zero_grad()

                target = reward[0] + gamma * torch.max(self.net1(state))
                loss = self.loss1(self.net1(self.last_state)[action - 1], target)
                loss.backward()
                self.optimizer1.step()
                # update the last state
                self.last_state = state
                # check if the game is over
                if j == 49:
                    print('\ndone No.', i + 1, '/', max_episodes, flush=True, end='\n')

    def test(self):
        # test the agent
        env = self.env
        state = env.reset()
        for i in range(100):
            state, _ = env.step([torch.argmax(self.net1(torch.tensor(state))).item() + 1])
            plt.imshow(state.reshape(6, 8))
            plt.show()
            display.display(plt.gcf())
            display.clear_output(wait=True)






In [52]:
config = {
    "n_player": 1,
    "board_width": 8,
    "board_height": 6,
    "n_beans": 5,
    "max_step": 50,
}

thisgame = Snake(config)
state = thisgame.reset()
agent = SnakeAgent(thisgame)
agent.train()

Now please use env.reset() to reset the game

done No. 1 / 200

done No. 2 / 200

done No. 3 / 200

done No. 4 / 200

done No. 5 / 200

done No. 6 / 200

done No. 7 / 200

done No. 8 / 200

done No. 9 / 200

done No. 10 / 200

done No. 11 / 200

done No. 12 / 200

done No. 13 / 200

done No. 14 / 200

done No. 15 / 200

done No. 16 / 200

done No. 17 / 200

done No. 18 / 200

done No. 19 / 200

done No. 20 / 200

done No. 21 / 200

done No. 22 / 200

done No. 23 / 200

done No. 24 / 200

done No. 25 / 200

done No. 26 / 200

done No. 27 / 200

done No. 28 / 200

done No. 29 / 200

done No. 30 / 200

done No. 31 / 200

done No. 32 / 200

done No. 33 / 200

done No. 34 / 200

done No. 35 / 200

done No. 36 / 200

done No. 37 / 200

done No. 38 / 200

done No. 39 / 200

done No. 40 / 200

done No. 41 / 200

done No. 42 / 200

done No. 43 / 200

done No. 44 / 200

done No. 45 / 200

done No. 46 / 200

done No. 47 / 200

done No. 48 / 200

done No. 49 / 200

done No. 50 / 200

done No. 51 /

In [53]:
agent.test()

KeyboardInterrupt: 

<Figure size 640x480 with 0 Axes>

In [13]:
import gymnasium as gym
import torch
import numpy as np
import torch.nn as nn
import random
import matplotlib.pyplot as plt
from IPython import display
%matplotlib inline

env = gym.make("Acrobot-v1", render_mode='rgb_array')
observation, info = env.reset(seed=42)
# for _ in range(1):
#     action = env.action_space.sample()  # this is where you would insert your policy
#     observation, reward, terminated, truncated, info = env.step(action)
#
#     if terminated or truncated:
#         observation, info = env.reset()
# env.close()
action = env.action_space
state = env.observation_space

"""
the general agent shoudl be able to:
1. be able to analysis the env and understand the state space and action space


having the following functions:
1. defineNet(self)
2. trainFromExperience(self, filename)
3. generateExperience(self)
4. trainFromSimulation(self)

"""


class Agent():
    def __init__(self, env, epsilon=1, epsilon_decay=0.995, epsilon_min=0.01, batch_size=32, discount_factor=0.9,
                 num_of_episodes=500):
        # hyper parameters
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.batch_size = batch_size
        self.discount_factor = discount_factor
        self.num_of_episodes = num_of_episodes
        self.env = env

        # get the shape of the state and action
        self.state_shape = env.observation_space.shape
        self.action_shape = env.action_space.shape

        # define the model
        self.__buildd_model()
        pass

    def __buildd_model(self):
        self.net = torch.nn.Sequential(
            nn.Linear(self.state_shape[0], 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, self.env.action_space.n)
        )
        self.loss = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.001)
        self.net.apply(self._init_weights)
        pass

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=1.0)

    def generateExperience(self, num_experiences):

        experiences = []
        state, _ = self.env.reset()

        for i in range(num_experiences):
            # dont convert the data to other type, keep it as its original type and handle them when using them
            # this strategy can reduce confusion

            last_state = state
            action = self.env.action_space.sample()
            state, reward, done, truncated, info = env.step(action)

            if type(state) == tuple:
                print("state is tuple")

            if done:
                # no need to know the order of the experiences
                state, _ = env.reset()
            else:
                experiences.append([last_state, action, reward, state, done])

        torch.save(experiences, 'data/experiences_test.pt')
        pass

    def replay(self, episodes=10, batch_size=50):
        memory = torch.load('data/experiences_test.pt')
        for i in range(episodes):
            batch = random.sample(memory, self.batch_size)

            for state, action, reward, next_state, done in batch:
                print(type(state))
                prediction = self.net(torch.tensor(state))
                target = prediction.clone()
                if not done:
                    target[action] = reward + self.discount_factor * torch.max(self.net(torch.tensor(next_state)))

                self.optimizer.zero_grad()
                self.loss(prediction, target).backward()
                self.optimizer.step()
        torch.save(self.net, 'data/model_test.pt')

    def test(self):

        pass


test = Agent(env)
test.generateExperience(10000)

experiences = torch.load('data/experiences_test.pt')
test.replay(episodes=1000, batch_size=50)

model = torch.load('data/model_test.pt')
model.eval()
state, _ = env.reset()
frames = []

import imageio

for i in range(100):
    action = model(torch.tensor(state)).argmax()
    state, reward, done, truncated, info = env.step(action)

    if done:
        state, _ = env.reset()
    img = env.render()
    frames.append(img)

env.close()

imageio.mimsave('test.gif', frames, duration=1 / 60)




<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.nd

KeyboardInterrupt: 