In [1]:
def code_to_move(code):
    """Translate a numeric move code into its direction name.

    The codes 0, 1, 2 and 3 map to "left", "up", "right" and "down"
    respectively. Any value that compares equal to none of them is
    handed back unchanged.
    """
    direction_names = ("left", "up", "right", "down")
    for index, name in enumerate(direction_names):
        # Plain equality, so anything equal to the index (e.g. 1.0) matches.
        if code == index:
            return name
    return code

In [2]:
from utils import build_model
from replay import Episode, ReplyBuffer
import numpy as np
from mazemap import Action, MazeMap, Mode
import tensorflowjs as tfjs

# Candidate 8x8 layout containing non-zero cells (presumably walls -- confirm
# against MazeMap). It is kept for reference only: the assignment further
# down rebinds the name before anything reads this array.
maze_test = np.array(
    [
        [0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 0, 1, 0],
        [1, 1, 1, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1, 1, 0, 0],
        [0, 1, 1, 1, 0, 0, 0, 1],
        [0, 1, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 1, 0, 0, 0, 1],
        [0, 0, 0, 1, 0, 0, 0, 0],
    ],
    dtype=float,
)

# The grid that is actually in effect: an all-zero 8x8 float array.
maze_test = np.zeros(shape=(8, 8))

def _save_model(model, save_path):
    """Write the weights to ``save_path + ".h5"`` and a TF.js export to ``./``."""
    model.save_weights(save_path + ".h5", overwrite=True)
    tfjs.converters.save_keras_model(model, './')

    print(f'Saved model in {save_path}')


def start_train(model,
                maze: MazeMap, 
                num_epoch = 15000, 
                max_buffer = 1000, 
                sample_size = 50,
                gamma = 0.9,
                load_path = None,
                save_path = None):
    """Train ``model`` on ``maze`` using an epsilon-greedy replay-buffer loop.

    Each epoch keeps acting in the maze until the maze reports ``Mode.END``
    (logged as a win) or ``Mode.TERMINATED`` (logged as a loss), or until no
    valid action is left. After every single step the transition is logged
    to the replay buffer and the model is fitted on a freshly sampled batch.

    The module-level ``epsilon`` is read as the exploration probability and
    is overwritten with 0.05 once the recent win rate exceeds 0.9.

    Args:
        model: Model exposing ``load_weights``, ``fit`` and ``save_weights``;
            it is also passed to ``tfjs.converters.save_keras_model``.
        maze: Environment to act in.
        num_epoch: Maximum number of epochs to run.
        max_buffer: Forwarded to ``ReplyBuffer`` (presumably its capacity --
            confirm against the replay module).
        sample_size: Number passed to ``replay_buf.sampling`` on each step.
        gamma: Forwarded to ``ReplyBuffer`` (presumably the discount factor
            -- confirm against the replay module).
        load_path: If given, weights are loaded from this path first.
        save_path: Base name for the saved weights (``".h5"`` is appended);
            defaults to ``'maze_model'``.

    Returns:
        None. Weights are saved every 15 epochs and once more on exit.
    """
    global epsilon

    if save_path is None:
        save_path = 'maze_model'

    if load_path is not None:
        print(f'Load weight from {load_path}')
        model.load_weights(load_path)

    replay_buf: ReplyBuffer = ReplyBuffer(model, maze.get_state_size(), max_buffer, gamma)

    # One entry per finished epoch: 1 for a win, 0 for a loss.
    history = []
    # Number of most recent outcomes the win rate is computed over.
    hsize = maze.get_state_size() // 2
    
    print("Initialization complete, begin training")
    # Run training epoch
    # NOTE(review): nothing in this function resets the maze between epochs,
    # so each epoch starts from wherever the previous one stopped -- confirm
    # whether MazeMap is expected to handle that internally.
    for epoch in range(num_epoch):
        loss = 0.
        is_over = False

        curr_state = maze.observe()
        print(curr_state.shape)
        num_episode = 0

        while not is_over:
            valid_actions = maze.get_valid_actions()
            # Explicit length check: a truthiness test would be ambiguous if
            # this turns out to be a numpy array.
            if len(valid_actions) == 0:
                break

            # Explore: the random draw always happens first so the sequence
            # of random-number calls is the same on either branch.
            action = np.random.choice(valid_actions)
            if np.random.rand() > epsilon:
                # Exploit: take the action the model scores highest. This is
                # not restricted to valid_actions.
                action = np.argmax(replay_buf.predict(curr_state))
            action = Action(action)
            print("Old loc:",maze.curr_loc)
            prev_state = curr_state
            curr_state, reward, mode = maze.act(action)
            mode = Mode(mode)
            print("New loc:",maze.curr_loc)

            print("chosen action:",code_to_move(action),"\treward:",reward)
            maze.print_maze()
            print()
            print(mode)
            if mode == Mode.END:
                history.append(1)
                is_over = True
            elif mode == Mode.TERMINATED:
                history.append(0)
                is_over = True

            episode = Episode(prev_state, curr_state, action, reward, mode)
            replay_buf.log(episode)
            num_episode += 1

            # Fit on a sampled batch after every step; only the final loss
            # of the last fit is reported for the epoch.
            inputs, outputs = replay_buf.sampling(sample_size)
            train_history = model.fit(inputs, outputs, epochs=8, batch_size=16, verbose=0)
            loss = train_history.history['loss'][-1]
        
        # The win rate stays at 0.0 until at least hsize outcomes exist.
        win_rate = 0.0 if len(history) < hsize else np.sum(np.array(history[-hsize:])) / hsize

        print(f'Epoch {epoch}/{num_epoch} | Loss: {loss:.2f} | Episodes: {num_episode} | Win Count: {np.sum(np.array(history))} | Win Rate: {win_rate}')

        if win_rate > 0.9:
            epsilon = 0.05
        
        if win_rate == 1.0:
            print('Reach 100% win rate')
            break

        if epoch % 15 == 0:
            _save_model(model, save_path)

    # Final save, reached both on normal completion and after the early break.
    _save_model(model, save_path)




# Exploration probability read by start_train: an action is taken at random
# unless a uniform draw exceeds this value, in which case the model's
# prediction is used instead.
epsilon = 0.1
maze_map = MazeMap(maze_test)
model = build_model(maze_test)
start_train(
    model,
    maze_map,
    num_epoch=300,
    max_buffer=8 * maze_map.get_state_size(),
)


Initialization complete, begin training
(1, 64)
Old loc: (0, 0)
New loc: (0, 0)
chosen action: Action.UP 	reward: -10
░░░░░░░░░░░░░░░░░░
░mm              ░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.INVALID
Old loc: (0, 0)
New loc: (1, 0)
chosen action: Action.DOWN 	reward: -0.5
░░░░░░░░░░░░░░░░░░
░                ░
░mm              ░
░                ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.VALID
Old loc: (1, 0)
New loc: (2, 0)
chosen action: Action.DOWN 	reward: -0.5
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░mm              ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.VALID
Old loc: (2, 0)
New loc: (3, 0)
chosen action: Action.DOWN 	reward: -0.5
░░░░░░░░░░░░░░░░░░
░                ░
░                

  return h5py.File(h5file)


Old loc: (7, 7)
New loc: (7, 7)
chosen action: Action.RIGHT 	reward: -10
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.INVALID
Old loc: (7, 7)
New loc: (7, 7)
chosen action: Action.RIGHT 	reward: -10
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.INVALID
Old loc: (7, 7)
New loc: (7, 6)
chosen action: Action.LEFT 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░            mmEE░
░░░░░░░░░░░░░░░░░░

Mode.VISITED
Old loc: (7, 6)
New loc: (7, 6)
chosen action: Action.DOWN 	reward: -10
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░   

Old loc: (0, 7)
New loc: (0, 7)
chosen action: Action.UP 	reward: -10
░░░░░░░░░░░░░░░░░░
░              mm░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.INVALID
Old loc: (0, 7)
New loc: (1, 7)
chosen action: Action.DOWN 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░              mm░
░                ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.TERMINATED
Epoch 13/300 | Loss: 10.01 | Episodes: 3 | Win Count: 4 | Win Rate: 0.0
(1, 64)
Old loc: (1, 7)
New loc: (1, 6)
chosen action: Action.LEFT 	reward: -0.5
░░░░░░░░░░░░░░░░░░
░                ░
░            mm  ░
░                ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.TERMINATED
Epoch 14/300 | Loss: 9.08 | Episodes: 1 | Win Count: 4 | Win Rate: 0.0
(1, 64)
Old loc: 

Old loc: (0, 0)
New loc: (0, 0)
chosen action: Action.LEFT 	reward: -10
░░░░░░░░░░░░░░░░░░
░mm              ░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.INVALID
Old loc: (0, 0)
New loc: (0, 0)
chosen action: Action.LEFT 	reward: -10
░░░░░░░░░░░░░░░░░░
░mm              ░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.INVALID
Old loc: (0, 0)
New loc: (0, 0)
chosen action: Action.LEFT 	reward: -10
░░░░░░░░░░░░░░░░░░
░mm              ░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.INVALID
Old loc: (0, 0)
New loc: (0, 0)
chosen action: Action.LEFT 	reward: -10
░░░░░░░░░░░░░░░░░░
░mm              ░
░                ░
░                ░
░                ░
░    

Epoch 49/300 | Loss: 0.31 | Episodes: 1 | Win Count: 7 | Win Rate: 0.09375
(1, 64)
Old loc: (7, 7)
New loc: (7, 6)
chosen action: Action.LEFT 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░            mmEE░
░░░░░░░░░░░░░░░░░░

Mode.TERMINATED
Epoch 50/300 | Loss: 0.63 | Episodes: 1 | Win Count: 7 | Win Rate: 0.09375
(1, 64)
Old loc: (7, 6)
New loc: (7, 7)
chosen action: Action.RIGHT 	reward: 10
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.END
Epoch 51/300 | Loss: 0.49 | Episodes: 1 | Win Count: 8 | Win Rate: 0.125
(1, 64)
Old loc: (7, 7)
New loc: (7, 6)
chosen action: Action.LEFT 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░

Epoch 72/300 | Loss: 0.53 | Episodes: 1 | Win Count: 17 | Win Rate: 0.40625
(1, 64)
Old loc: (7, 6)
New loc: (7, 7)
chosen action: Action.RIGHT 	reward: 10
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.END
Epoch 73/300 | Loss: 0.07 | Episodes: 1 | Win Count: 18 | Win Rate: 0.4375
(1, 64)
Old loc: (7, 7)
New loc: (7, 6)
chosen action: Action.LEFT 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░            mmEE░
░░░░░░░░░░░░░░░░░░

Mode.TERMINATED
Epoch 74/300 | Loss: 0.06 | Episodes: 1 | Win Count: 18 | Win Rate: 0.4375
(1, 64)
Old loc: (7, 6)
New loc: (7, 7)
chosen action: Action.RIGHT 	reward: 10
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░             

Old loc: (6, 7)
New loc: (7, 7)
chosen action: Action.DOWN 	reward: 10
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.END
Epoch 94/300 | Loss: 0.01 | Episodes: 2 | Win Count: 30 | Win Rate: 0.53125
(1, 64)
Old loc: (7, 7)
New loc: (7, 6)
chosen action: Action.LEFT 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░            mmEE░
░░░░░░░░░░░░░░░░░░

Mode.VISITED
Old loc: (7, 6)
New loc: (7, 7)
chosen action: Action.RIGHT 	reward: 10
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.END
Epoch 95/300 | Loss: 0.01 | Episodes: 2 | Win Count: 31 | Win Rate: 0.53125
(1, 64)
Old loc: (7, 7

Epoch 106/300 | Loss: 0.02 | Episodes: 2 | Win Count: 42 | Win Rate: 0.75
(1, 64)
Old loc: (7, 7)
New loc: (6, 7)
chosen action: Action.UP 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░              mm░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.VISITED
Old loc: (6, 7)
New loc: (7, 7)
chosen action: Action.DOWN 	reward: 10
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.END
Epoch 107/300 | Loss: 0.02 | Episodes: 2 | Win Count: 43 | Win Rate: 0.75
(1, 64)
Old loc: (7, 7)
New loc: (6, 7)
chosen action: Action.UP 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░              mm░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.VISITED
Old loc: (6, 7)
New

Old loc: (6, 7)
New loc: (6, 6)
chosen action: Action.LEFT 	reward: -0.5
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░            mm  ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.VALID
Old loc: (6, 6)
New loc: (6, 5)
chosen action: Action.LEFT 	reward: -0.5
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░          mm    ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.VALID
Old loc: (6, 5)
New loc: (5, 5)
chosen action: Action.UP 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░          mm    ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.VISITED
Old loc: (5, 5)
New loc: (4, 5)
chosen action: Action.UP 	reward: -0.5
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░          

Old loc: (0, 0)
New loc: (1, 0)
chosen action: Action.DOWN 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░mm              ░
░                ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.VISITED
Old loc: (1, 0)
New loc: (0, 0)
chosen action: Action.UP 	reward: -1
░░░░░░░░░░░░░░░░░░
░mm              ░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.VISITED
Old loc: (0, 0)
New loc: (1, 0)
chosen action: Action.DOWN 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░mm              ░
░                ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.VISITED
Old loc: (1, 0)
New loc: (0, 0)
chosen action: Action.UP 	reward: -1
░░░░░░░░░░░░░░░░░░
░mm              ░
░                ░
░                ░
░                ░
░            

Old loc: (2, 2)
New loc: (2, 1)
chosen action: Action.LEFT 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░  mm            ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.VISITED
Old loc: (2, 1)
New loc: (2, 0)
chosen action: Action.LEFT 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░mm              ░
░                ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.VISITED
Old loc: (2, 0)
New loc: (3, 0)
chosen action: Action.DOWN 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░mm              ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.VISITED
Old loc: (3, 0)
New loc: (3, 1)
chosen action: Action.RIGHT 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░  mm            ░
░       

Old loc: (3, 3)
New loc: (4, 3)
chosen action: Action.DOWN 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░      mm        ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.VISITED
Old loc: (4, 3)
New loc: (5, 3)
chosen action: Action.DOWN 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░      mm        ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.VISITED
Old loc: (5, 3)
New loc: (6, 3)
chosen action: Action.DOWN 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░                ░
░                ░
░      mm        ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.VISITED
Old loc: (6, 3)
New loc: (7, 3)
chosen action: Action.DOWN 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░        

Old loc: (4, 3)
New loc: (4, 2)
chosen action: Action.LEFT 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░    mm          ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.VISITED
Old loc: (4, 2)
New loc: (3, 2)
chosen action: Action.UP 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░    mm          ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.VISITED
Old loc: (3, 2)
New loc: (3, 3)
chosen action: Action.RIGHT 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░      mm        ░
░                ░
░                ░
░                ░
░              EE░
░░░░░░░░░░░░░░░░░░

Mode.VISITED
Old loc: (3, 3)
New loc: (4, 3)
chosen action: Action.DOWN 	reward: -1
░░░░░░░░░░░░░░░░░░
░                ░
░                ░
░                ░
░                ░
░      mm 