In [3]:
import time

from env import Env
from dqn import DQN

In [4]:
# Game maps are defined as functions so that different layouts can be swapped
# in later to observe how the agent behaves on each of them.
def game_map_1(environment):
    """Reset ``environment`` and populate it with the layout of map 1.

    The layout consists of two pickable ``'yellow_star'`` items worth
    ``credit=100`` at (3, 3) and (0, 7), followed by one terminal
    ``'red_ball'`` item labelled "Exit" at (5, 6).

    Args:
        environment: object exposing ``reset()`` and ``add_item(...)``.

    Returns:
        None. The environment is modified in place.
    """
    environment.reset()

    # Both stars share the same reward and can be picked up by the agent.
    star_options = {'credit': 100, 'pickable': True}
    for star_position in ((3, 3), (0, 7)):
        environment.add_item('yellow_star', star_position, **star_options)

    # The terminal item is the exit.
    environment.add_item('red_ball', (5, 6), terminal=True, label="Exit")

In [5]:
# Build the environment that every later cell operates on.
env = Env(
    (8, 8),             # presumably the grid size (rows, columns) -- confirm in env.py
    (130, 90),          # NOTE(review): meaning not visible here -- confirm in env.py
    default_rewards=0,
)

In [6]:
# Select a game: game_map_1 resets `env` and then adds its stars and exit.
game_map_1(env)

In [7]:
# Smoke test: take one random step and print what the environment returns,
# then put the environment back into its initial state.
for _ in range(1):
    action = env.action_space.sample()
    print(action)
    # The middle value is bound to `next_location` rather than `next`, so the
    # builtin next() is not shadowed for the rest of the kernel session.
    reward, next_location, end = env.step(action)
    print(reward, next_location, end)
    time.sleep(0.2)

env.reset()

E
0 (0, 1) False


In [8]:
# The two functions below convert grid locations into state vectors.
def location_one_hot(location, map_dimension):
    """Encode one grid location as a flat one-hot list.

    Args:
        location: ``(row, column)`` pair, zero-based.
        map_dimension: ``(total_rows, total_columns)`` pair.

    Returns:
        A list of ``total_rows * total_columns`` integers that is all zeros
        except for a single 1 at index ``row * total_columns + column``.

    Raises:
        AssertionError: if the location lies outside the map.
    """
    row, column = location
    total_rows, total_columns = map_dimension

    # The lower bound matters too: a negative row or column would otherwise
    # wrap around through negative list indexing and mark the wrong cell.
    assert 0 <= row < total_rows and 0 <= column < total_columns, \
        "location {} is outside a {}x{} map".format(location, total_rows, total_columns)

    # Merge row and column into one row-major ID used as the one-hot index.
    location_id = row * total_columns + column

    one_hot = [0] * (total_rows * total_columns)
    one_hot[location_id] = 1

    return one_hot

def location_multi_hot(locations, map_dimension):
    """Encode several grid locations as one flat multi-hot list.

    Args:
        locations: iterable of ``(row, column)`` pairs, zero-based. May be
            empty, in which case the result is all zeros.
        map_dimension: ``(total_rows, total_columns)`` pair.

    Returns:
        A list of ``total_rows * total_columns`` integers with a 1 at index
        ``row * total_columns + column`` for every given location and 0
        elsewhere. Repeated locations set the same entry once.

    Raises:
        AssertionError: if any location lies outside the map.
    """
    total_rows, total_columns = map_dimension
    multi_hot = [0] * (total_rows * total_columns)

    for row, column in locations:
        # The lower bound matters too: a negative row or column would
        # otherwise wrap around through negative list indexing and mark the
        # wrong cell.
        assert 0 <= row < total_rows and 0 <= column < total_columns, \
            "location {} is outside a {}x{} map".format((row, column), total_rows, total_columns)

        # Merge row and column into one row-major ID used as the index.
        multi_hot[row * total_columns + column] = 1

    return multi_hot

# The function below converts the whole environment into one state vector.
def state_from_environment(environment):
    """Build the flat state vector describing ``environment``.

    The state is the concatenation of three lists, each with one entry per
    map square: a one-hot of the agent position, a one-hot of the exit
    position, and a multi-hot of the positions of all pickable items.

    Args:
        environment: object exposing ``map.n_rows``, ``map.n_columns``,
            ``map.all_items`` and ``agent.at``. Each item must expose
            ``pickable``, ``terminal`` and ``index``.
            NOTE(review): ``agent.at`` and ``item.index`` are assumed to be
            ``(row, column)`` pairs -- confirm in env.py.

    Returns:
        ``agent_state + exit_state + stars_state`` as a single list.

    Raises:
        AssertionError: if an item is neither pickable nor terminal, if no
            terminal item exists, or if a location lies outside the map.
    """
    # Size of the environment map.
    dimension = (environment.map.n_rows, environment.map.n_columns)

    # Agent part of the state.
    agent_state = location_one_hot(environment.agent.at, dimension)

    star_locations = []
    exit_location = None
    for item in environment.map.all_items:
        if item.pickable:
            star_locations.append(item.index)
        elif item.terminal:
            # If several terminal items exist, the last one seen wins.
            exit_location = item.index
        else:
            # Raised explicitly: a bare `assert False` disappears under
            # `python -O`, which would silently ignore unknown items.
            raise AssertionError("Unknown item in the environment")

    # The agent must always have an exit to reach.
    assert exit_location is not None, "You must have a exit point for agent!"

    # Exit part of the state.
    exit_state = location_one_hot(exit_location, dimension)

    # Star part of the state.
    stars_state = location_multi_hot(star_locations, dimension)

    # Concatenate all three parts.
    return agent_state + exit_state + stars_state

In [9]:
# Hyperparameters.
# `lr` and `gamma` are passed to the DQN constructor; by convention they are
# the learning rate and the reward discount factor -- confirm in dqn.py.
lr = 0.001
gamma = 0.99

# Number of training episodes. Kept equal to the value the training cell
# assigns right before its loop, so this cell states what actually runs.
training_episodes = 1000

In [10]:
# state_from_environment() concatenates three per-square vectors (agent, exit,
# stars), hence three entries per map square.
# NOTE(review): assumes env.map.n_squares == n_rows * n_columns -- confirm in env.py.
state_size = env.map.n_squares * 3
n_actions = env.action_space.n_actions

dqn = DQN(state_size, n_actions, lr, gamma)

In [None]:
# Generate experience with a random policy, to be used for training later.
#
# The number of episodes is bounded so that this cell terminates under
# "Restart Kernel -> Run All"; an unbounded `while True:` here would hang the
# notebook before the training cell is ever reached.
random_episodes = 10
random_experience = []  # one list of transitions per episode

for _ in range(random_episodes):
    env.reset()
    env.show = False

    this_episode = []
    state = state_from_environment(env)
    end = False
    while not end:
        action = env.action_space.sample()
        reward, next_location, end = env.step(action)
        next_state = state_from_environment(env)

        # NOTE(review): `action` is whatever env.action_space.sample()
        # returns, not the integer action id the DQN works with -- convert it
        # before feeding these transitions to the DQN.
        this_episode.append((state, action, reward, next_state))
        state = next_state

    random_experience.append(this_episode)
        

In [11]:
# Train the DQN: play one episode with the current policy, then train on it.
training_episodes = 1000
total_losses = 0
for episode in range(1, training_episodes + 1):
    env.reset()  # put the environment back into its initial state
    env.show = False

    # State of the freshly reset environment.
    state = state_from_environment(env)

    # Transitions experienced by the agent during this episode.
    this_episode = []

    end = False  # becomes true once env.step() reports the episode is over
    while not end:
        # Ask the DQN for an action given the current state.
        # The DQN encodes actions as consecutive integers starting at 0, both
        # internally and in its inputs and outputs. The environment may use a
        # different representation, so the id is converted before stepping.
        action_id = dqn.next_action(state)
        action = env.action_space.action_from_id(action_id)

        # Move the agent one step. The environment returns the reward for the
        # step, the agent's new location and whether the episode has ended.
        reward, next_location, end = env.step(action)

        # State of the environment after the step.
        next_state = state_from_environment(env)

        # Record the transition, using the DQN's action id.
        this_episode.append((state, action_id, reward, next_state))

        state = next_state

    # Train the DQN on the finished episode.
    loss = dqn.train_an_episode(this_episode)
    total_losses += loss
    # `episode` starts at 1, so the modulo test alone is enough.
    if episode % 200 == 0:
        print("trained episodes {}: current loss is {:.4f}, average loss is {:.4f}".format(episode, loss, total_losses/episode))

trained episodes 200: current loss is 377.2625, average loss is 149.1892
trained episodes 400: current loss is 222.8826, average loss is 150.8623
trained episodes 600: current loss is 31.8806, average loss is 129.5653
trained episodes 800: current loss is 9.6557, average loss is 113.5371
trained episodes 1000: current loss is 48.7760, average loss is 103.3547
