In [None]:
import numpy as np
import time

from env import Env
from dqn import DQN
from chessboard import Chessboard

In [None]:
# Game maps are defined here; different maps can be selected later to observe
# how the agent behaves on each of them.
def game_map_1(environment):
    """Reset ``environment`` and populate it with the layout of map 1.

    The layout consists of two pickable 'yellow_star' items (credit 1000,
    labels 1 and 2) at (3, 3) and (0, 7), and one terminal 'red_ball' item
    labelled "Exit" at (5, 6).

    Args:
        environment: object exposing ``reset()`` and ``add_item(...)``.
    """
    environment.reset()

    # (item kind, cell, keyword options) in the order the items are added.
    layout = (
        ('yellow_star', (3, 3), dict(credit=1000, pickable=True, label=1)),
        ('yellow_star', (0, 7), dict(credit=1000, pickable=True, label=2)),
        ('red_ball', (5, 6), dict(terminal=True, label="Exit")),
    )
    for kind, cell, options in layout:
        environment.add_item(kind, cell, **options)

In [None]:
# set the environment
# `env` is the shared environment instance used by every later cell.
# NOTE(review): (8, 8) is presumably the map size in rows x columns (later cells
# pass (8, 8) as map_size) and (160, 90) presumably a rendering size — confirm
# against Env's signature.
env = Env((8, 8), (160, 90), default_rewards=0)

In [None]:
# select a game
# Resets `env` and adds the items of map 1 (two stars and one exit).
game_map_1(env)

In [None]:
# Smoke test: take one random action, print what the environment returns,
# then reset the environment so later cells start from a fresh episode.
for _ in range(1):
    action = env.action_space.sample()
    print(action)
    # The second value is bound to `next_location` rather than `next`, so the
    # builtin next() is not shadowed for the rest of the kernel session.
    reward, next_location, end, _ = env.step(action)
    print(reward, next_location, end)
    time.sleep(0.2)

env.reset()

In [None]:
def compose_state(agent_loc, items_picked, map_size, tunnel=0):
    """Build the one-hot style state vector for an agent position.

    Args:
        agent_loc: (row, column) of the agent.
        items_picked: iterable of (row, column) locations of collected items.
        map_size: (rows, columns) of the map.
        tunnel: extra value appended as the last element (default 0).

    Returns:
        A list made of the one-hot agent location, followed by the multi-hot
        encoding of ``items_picked``, followed by ``tunnel``.
    """
    # Agent position.
    agent_part = location_one_hot(agent_loc, map_size)

    # Items the agent has already collected.
    picked_part = location_multi_hot(items_picked, map_size)

    return [*agent_part, *picked_part, tunnel]

def centralize_range(start, stop=0):
    """Return the bounds of the inclusive range start..stop centred on zero.

    When ``stop`` is 0 the call is interpreted as the range ``0..start``.

    Args:
        start: first value of the range (or the last one when ``stop`` is 0).
        stop: last value of the range; 0 selects the one-argument form.

    Returns:
        ``(-m, m)`` where ``m`` is half the distance between the two bounds.

    Raises:
        AssertionError: if the upper bound is below the lower bound.
    """
    lower, upper = (0, start) if stop == 0 else (start, stop)

    assert upper >= lower

    count = upper - lower + 1       # number of integer points in the range
    half_width = (count - 1) / 2

    return -half_width, half_width

def regularize_range(start, stop=0):
    """Map the inclusive integer range start..stop onto evenly spaced values in [-1, 1].

    The range is first centred on zero with ``centralize_range`` and then
    divided by its half-width, so the first value is -1 and the last is 1.

    Args:
        start: first value of the range (or the last one when ``stop`` is 0).
        stop: last value of the range; 0 selects the one-argument form 0..start.

    Returns:
        A list with one scaled value per point of the range. A range holding a
        single point yields ``[0.0]``.
    """
    begin, end = centralize_range(start, stop)
    half = (end - begin) / 2

    # A one-point range has zero half-width; dividing by it would produce nan
    # (0/0), so the single point is placed at the centre instead.
    if half == 0:
        return [0.0]

    return [i / half for i in np.arange(begin, end + 1)]

# test
# Quick visual check of the two range helpers on the range 3..6 (4 points).
arg = (3, 6)
start, stop = centralize_range(*arg)
# Centred bounds of a 4-point range: prints -1.5 1.5
print(start, stop)
numbers = regularize_range(*arg)
# Four evenly spaced values running from -1 to 1.
print(numbers)

In [None]:
def regularize_location(loc, map_size):
    """Scale a (row, column) grid location into the [-1, 1] range on each axis.

    Args:
        loc: (row, column) index pair.
        map_size: (height, width) of the map.

    Returns:
        Tuple ``(scaled_row, scaled_column)`` looked up from the per-axis
        tables produced by ``regularize_range``.
    """
    row_idx, col_idx = loc
    n_rows, n_cols = map_size

    # One scaled value per row index 0..n_rows-1 and per column index 0..n_cols-1.
    row_scale = regularize_range(n_rows - 1)
    col_scale = regularize_range(n_cols - 1)

    return row_scale[row_idx], col_scale[col_idx]
    
# Quick check: scale cell (0, 1) of an 8x8 map; row 0 maps to -1.
r, c = regularize_location((0, 1), (8, 8))
print("location:", r, c)
    
def compose_state2(agent_loc, encode, map_size):
    """Build the compact 3-element state vector.

    Args:
        agent_loc: (row, column) of the agent.
        encode: value placed in the third slot (callers pass the code of the
            picked items).
        map_size: (rows, columns) of the map.

    Returns:
        ``[scaled_row, scaled_column, encode]``.
    """
    scaled_row, scaled_col = regularize_location(agent_loc, map_size)
    state = [scaled_row, scaled_col]
    state.append(encode)
    return state

# Quick check of the two opposite corners of an 8x8 map:
# the top-left corner scales to (-1, -1) and the bottom-right corner to (1, 1).
state = compose_state2((0, 0), 0, (8, 8))
print(state)
state = compose_state2((7, 7), 1, (8, 8))
print(state)

In [None]:
# The two functions below convert location information into state vectors.
def location_one_hot(location, map_dimension):
    """One-hot encode a single grid cell.

    Args:
        location: (row, column) of the cell.
        map_dimension: (total_rows, total_columns) of the map.

    Returns:
        A list of ``total_rows * total_columns`` zeros with a single 1 at
        index ``row * total_columns + column``.

    Raises:
        AssertionError: if the location lies outside the map. Negative
            coordinates are rejected as well; otherwise they would wrap
            around and silently mark a different cell.
    """
    row, column = location
    total_rows, total_columns = map_dimension

    assert 0 <= row < total_rows and 0 <= column < total_columns, (
        "location {} is outside a {}x{} map".format(location, total_rows, total_columns))

    # Merge row and column into one row-major ID used as the one-hot index.
    location_id = row * total_columns + column

    one_hot = [0] * (total_rows * total_columns)
    one_hot[location_id] = 1

    return one_hot

def location_multi_hot(locations, map_dimension):
    """Multi-hot encode a collection of grid cells.

    Args:
        locations: iterable of (row, column) pairs; may be empty.
        map_dimension: (total_rows, total_columns) of the map.

    Returns:
        A list of ``total_rows * total_columns`` entries holding 1 at the
        row-major index of every given location and 0 elsewhere.

    Raises:
        AssertionError: if any location lies outside the map. Negative
            coordinates are rejected as well; otherwise they would wrap
            around and silently mark a different cell.
    """
    total_rows, total_columns = map_dimension
    multi_hot = [0] * (total_rows * total_columns)

    for row, column in locations:
        assert 0 <= row < total_rows and 0 <= column < total_columns, (
            "location {} is outside a {}x{} map".format(
                (row, column), total_rows, total_columns))

        # Merge row and column into one row-major ID used as the vector index.
        multi_hot[row * total_columns + column] = 1

    return multi_hot

# The function below converts the environment's information into a state vector.
def state_from_environment_old(environment):
    """Legacy state encoder based on one-hot / multi-hot vectors.

    Args:
        environment: object exposing ``map.n_rows``, ``map.n_columns``,
            ``map.all_items``, ``agent.at`` and ``agent.bag_of_objects``.

    Returns:
        A list made of the one-hot agent location, followed by the multi-hot
        locations of the items in the agent's bag, followed by a trailing 0.

    Raises:
        AssertionError: if the map holds an item that is neither pickable nor
            terminal, or if the map has no terminal (exit) item.
    """
    # Size of the environment map.
    dimension = (environment.map.n_rows, environment.map.n_columns)

    # Agent position.
    agent_state = location_one_hot(environment.agent.at, dimension)

    # Validate the map contents. Star and exit locations are not part of the
    # returned state, so only the presence of an exit is tracked here.
    exit_location = None
    for item in environment.map.all_items:
        if item.pickable:
            continue
        if item.terminal:
            exit_location = item.index
        else:
            raise AssertionError("Unknown item in the environment")

    # The agent must always have an exit to reach.
    if exit_location is None:
        raise AssertionError("You must have a exit point for agent!")

    # Locations of the items the agent is carrying.
    item_locations = [item.index for item in environment.agent.bag_of_objects]
    items_state = location_multi_hot(item_locations, dimension)

    return agent_state + items_state + [0]

def state_from_environment(env):
    """Encode the current environment as the compact 3-element state.

    The items in the agent's bag are folded into one number by adding
    ``2 ** (label - 1)`` for each item, so labels must be positive.

    Args:
        env: object exposing ``map.n_rows``, ``map.n_columns``, ``agent.at``
            and ``agent.bag_of_objects`` (items carrying a ``label``).

    Returns:
        The list produced by ``compose_state2`` for the agent's location and
        the bag code.

    Raises:
        AssertionError: if an item in the bag has a label that is not > 0.
    """
    # Size of the environment map.
    map_size = (env.map.n_rows, env.map.n_columns)

    bag_code = 0
    for picked in env.agent.bag_of_objects:
        label = picked.label  # labels are numbered starting from 1
        assert label > 0
        bag_code += 2 ** (label - 1)

    return compose_state2(env.agent.at, bag_code, map_size)

# Reset the environment and print the encoded state right after the reset.
env.reset()
state = state_from_environment(env)
print(state)

In [None]:
# hyperparameters
# Both values are passed positionally to the DQN constructor in the next cell.
# NOTE(review): presumably the optimizer learning rate and the reward discount
# factor — confirm against DQN's signature.
lr = 0.001
gamma = 0.9

In [None]:
# Build the DQN used by all later cells.
# The first argument (3) matches the length of the state returned by
# compose_state2: [scaled row, scaled column, picked-item code].
# The commented expression below looks like the input size of the older one-hot
# encoding from compose_state — TODO confirm that n_squares == rows * columns.
# NOTE(review): [64, 32] are presumably hidden-layer sizes — confirm against DQN.
# env.map.n_squares*2+1
dqn = DQN(3, env.action_space.n_actions, [64, 32], lr, gamma, experience_limit=40000)

In [None]:
# Display the action-value information.
def show_action_values(env):
    """Show the DQN's action values for the agent's current state.

    Encodes ``env`` with ``state_from_environment``, queries the module-level
    ``dqn`` with a batch holding that single state, and hands a mapping
    ``{action: value rounded to 2 decimals, as text}`` to ``env.draw_text``
    at the agent's location.

    Args:
        env: the environment to read the state from and to draw on.
    """
    here = env.agent.at

    features = np.array(state_from_environment(env))
    # Add a leading batch axis: the network is queried with one row.
    batch = features.reshape((1, *features.shape))
    q_values = dqn.action_values(batch)[0]

    labels = {
        env.action_space.action_from_id(idx): str(np.round(q, 2))
        for idx, q in enumerate(q_values)
    }

    env.draw_text(here, labels)
    
def _show_state(state, location):
    """Show the DQN's action values for an arbitrary state at ``location``.

    Uses the module-level ``dqn`` and ``env``: the state is evaluated as a
    batch of one, and a mapping ``{action: value rounded to 2 decimals, as
    text}`` is handed to ``env.draw_text``.

    Args:
        state: sequence of state features accepted by the network.
        location: (row, column) cell passed to ``env.draw_text``.
    """
    features = np.array(state)
    # Add a leading batch axis: the network is queried with one row.
    batch = features.reshape((1, *features.shape))
    q_values = dqn.action_values(batch)[0]

    labels = {
        env.action_space.action_from_id(idx): str(np.round(q, 2))
        for idx, q in enumerate(q_values)
    }

    env.draw_text(location, labels)
    
def show_all_state(env, picked_items, tunnel=0):
    """Show action values for every cell of the map using the one-hot state.

    Cells are visited row by row; each one is encoded with ``compose_state``
    and displayed through ``_show_state``.

    NOTE(review): ``compose_state`` produces a vector much longer than the
    3 inputs the ``dqn`` in this notebook is built with — confirm this helper
    is only used with a network sized for the one-hot encoding.

    Args:
        env: environment providing ``map.n_rows`` and ``map.n_columns``.
        picked_items: iterable of (row, column) locations of collected items.
        tunnel: extra value forwarded to ``compose_state`` (default 0).
    """
    map_size = (env.map.n_rows, env.map.n_columns)
    n_rows, n_cols = map_size

    for row in range(n_rows):
        for col in range(n_cols):
            cell = (row, col)
            _show_state(compose_state(cell, picked_items, map_size, tunnel=tunnel), cell)
        
def show_all_state2(env, items_encode):
    """Show action values for every cell of the map using the compact state.

    Cells are visited row by row; each one is encoded with ``compose_state2``
    and displayed through ``_show_state``.

    Args:
        env: environment providing ``map.n_rows`` and ``map.n_columns``.
        items_encode: picked-item code placed in the third state slot.
    """
    map_size = (env.map.n_rows, env.map.n_columns)
    n_rows, n_cols = map_size

    for row in range(n_rows):
        for col in range(n_cols):
            cell = (row, col)
            _show_state(compose_state2(cell, items_encode, map_size), cell)

In [None]:
# Randomly generate experience data used for training later on.
def sample_experience():
    """Fill the DQN experience buffer with transitions from random play.

    Uses the module-level ``env`` and ``dqn``. Complete episodes are played
    with uniformly sampled actions (``env.show`` is switched off) and every
    transition ``(state, action_id, reward, next_state, end)`` is stored via
    ``dqn.fill_experience``. The buffer is checked only after an episode has
    finished; sampling stops once ``dqn.experience.is_full`` is true.
    """
    while True:
        env.reset()
        env.show = False

        # Initial state of the episode.
        state = state_from_environment(env)
        end = False
        while not end:
            # Pick a random action.
            action = env.action_space.sample()
            action_id = env.action_space.action_id(action)
            reward, _next_location, end, _ = env.step(action)

            # State of the environment after the agent has taken one step.
            next_state = state_from_environment(env)
            dqn.fill_experience((state, action_id, reward, next_state, end))
            state = next_state

        # Stop once the experience buffer is full so pre-training can begin.
        if dqn.experience.is_full:
            break
        
    
# Pre-training: repeat n_pretrains times — clear the experience buffer, refill
# it with random-play transitions, then run 100 training batches of 400 samples
# and print the loss returned for each batch.
n_pretrains = 30
for pretrain in range(n_pretrains):
    dqn.clear_experience()
    sample_experience()
    for i in range(100):
        batch_size = 400
        loss = dqn.train_batch_states(batch_size)

        print("experience sample {} pretrain batch {}: loss is {}".format(
                                                pretrain,
                                                i, loss))

In [None]:
# Main training loop: play `training_episodes` episodes with the DQN's policy,
# store every transition as experience, and train after each episode.
#dqn.clear_experience()
training_episodes = 50000
# NOTE(review): total_losses is accumulated below but never read in this cell.
total_losses = 0

for episode in range(1, training_episodes+1):
    env.reset()  # reset the environment
    #env.show = True
    env.show = False
        
    # The environment has just been reset; capture its state.
    state = state_from_environment(env)
    
    # Track the agent's position before and after each step.
    location = env.agent.at
    next_location = location
    
    this_episode = []
    hit_walls = 0
    end = False  # whether this episode has finished
    while end == False:
        # Display the action values of the current state.
        #show_action_values(env)
        
        # Ask the DQN for the action to take in the current state.
        # Note: with a DQN, actions are usually encoded as consecutive integers
        # starting from 0; the DQN uses these numbers internally and for its
        # inputs and outputs.
        # The environment may not represent actions as numbers, so a conversion is needed.
        action_id = dqn.next_action(state, episode)
        action = env.action_space.action_from_id(action_id)

        # Let the agent take one step; the environment returns the reward of the
        # step, the agent's new location and whether the agent reached the exit.
        reward, next_location, end, _ = env.step(action)
        #time.sleep(0.05)

        # Get the next state and record this step as experience.
        next_state = state_from_environment(env)
        one_step = (state, action_id, reward, next_state, end)
        dqn.fill_experience(one_step)
        this_episode.append(one_step)
        
        # An unchanged location after a step is counted as hitting a wall.
        if location == next_location:
            hit_walls += 1
        
        state = next_state
        location = next_location

        # Cap the episode length at 400 steps.
        if env.steps >= 400:
            break
            
        # Debug output
        #print("state:", state)
        #print("common state:", common_state(state))
        #print("step {}: {} ----> {} {} reward {}\n".format(env.steps, location, next_location, action, reward))
        #print(next_state)
            
    # One batch of experience training per episode.
    batch_size = 100
    loss = dqn.train_batch_states(batch_size)
    total_losses += loss

    # Train on the whole episode, repeated 50 times.
    for _ in range(50):
        dqn.train_an_episode(this_episode)
    
    # NOTE(review): `reward` printed below is the reward of the last step only,
    # not the episode total; `loss / batch_size` assumes train_batch_states
    # returns a summed loss — TODO confirm against DQN.
    #print("items remain in map:", len(env.pickable_items()))
    print("# {}: batch avg loss is {:.4f}, agent moved {} steps, hit walls {}, reward is {}".format(
                                            episode, 
                                            loss / batch_size,
                                            env.steps,
                                            hit_walls,
                                            reward))
    #env.show = True
    #show_all_state2(env, 0)
        

In [None]:
# Empty the DQN's experience buffer, so the hand-made transitions added by the
# test_train helpers below are the only experience stored.
dqn.clear_experience()

In [None]:
def test_train(batch):
    """Train on hand-made transitions next to the cell (5, 6) and show the result.

    Four synthetic transitions on an 8x8 map are stored ``batch`` times via
    ``dqn.fill_experience``; then one training batch of size ``batch`` is run,
    rendering is enabled and the action values of every cell are displayed.
    Uses the module-level ``env`` and ``dqn``.

    Args:
        batch: number of repetitions of the transition set, also used as the
            training batch size.
    """
    map_size = (8, 8)

    # Hand-composed (from cell, to cell, action, reward, terminal flag).
    transitions = (
        ((5, 5), (5, 6), 'E', 1000, True),
        ((6, 5), (5, 5), 'N', 0, False),
        ((6, 6), (5, 6), 'N', 1000, True),
        ((6, 5), (6, 6), 'E', 0, False),
    )

    for _ in range(batch):
        for src, dst, move, reward, terminal in transitions:
            state = compose_state2(src, 0, map_size)
            next_state = compose_state2(dst, 0, map_size)
            action_id = env.action_space.action_id(move)
            dqn.fill_experience((state, action_id, reward, next_state, terminal))

    dqn.train_batch_states(batch)
    env.show = True
    show_all_state2(env, 0)
    
#test_train(400)

In [None]:
def test_train_loop(batch):
    """Train on a single self-loop transition and show the result.

    Stores ``batch`` copies of the transition "at (0, 1), action 'N', stay at
    (0, 1), reward 0, terminal" on an 8x8 map, runs one training batch of size
    ``batch``, enables rendering and displays the action values of every cell.
    Uses the module-level ``env`` and ``dqn``.

    Args:
        batch: number of stored transitions, also used as the training batch size.
    """
    cell, map_size = (0, 1), (8, 8)

    for _ in range(batch):
        # Fresh state lists are built for every stored transition.
        before = compose_state2(cell, 0, map_size)
        after = compose_state2(cell, 0, map_size)
        dqn.fill_experience((before, env.action_space.action_id('N'), 0, after, True))

    dqn.train_batch_states(batch)
    env.show = True
    show_all_state2(env, 0)
    
#test_train_loop(400) 

In [None]:
# Print the encoded state of the environment as it currently is, then enable
# rendering and display the action values of every cell for picked-item code 0.
state = state_from_environment(env)
print(state)
env.show = True
show_all_state2(env, 0)

In [None]:
# ---- Test ----
# Reset the environment and play one episode with rendering enabled.
env.reset()
env.show = True

# The greedy policy is deterministic: if the chosen move leaves the state
# unchanged, the same move is chosen again forever. Cap the episode length
# with the same limit the training loop uses.
# NOTE(review): relies on env.reset() resetting env.steps, as the training
# loop also does — confirm against Env.
max_test_steps = 400

end = False  # whether this episode has finished
while not end:
    # debug
    show_action_values(env)

    # Get the state of the environment.
    state = state_from_environment(env)

    # Ask the DQN for the action with the highest value in this state.
    action_id = dqn.best_action(state)
    action = env.action_space.action_from_id(action_id)

    # debug
    #print("action:", action)

    # The agent performs the action.
    reward, next_location, end, _ = env.step(action)
    time.sleep(0.2)

    if not end and env.steps >= max_test_steps:
        print("test episode stopped after {} steps without reaching the exit".format(env.steps))
        break
