In [1]:
from obstacle_tower_env import ObstacleTowerEnv

import numpy as np
from collections import deque

%matplotlib inline
from matplotlib import pyplot as plt

import torch
# Put tensors on the first CUDA GPU when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [2]:
from Source.Agents import DQN_Agent
from Source.Buffer import Buffer

def ProcessState(state):
    """Wrap one observation in a batch of size 1 and move its last axis to position 1.

    For a channels-last image of shape (H, W, C) the result has shape
    (1, C, H, W). NOTE(review): assumes the env's visual observation is
    channels-last (H, W, C) -- confirm against the agent's network input.

    Returns a new numpy array (the input is copied by ``np.array``).
    """
    batch = np.array([state])  # (1, *state.shape)
    # np.moveaxis is the documented replacement for the legacy np.rollaxis.
    return np.moveaxis(batch, -1, 1)

def ProcessAction(action, split_points=(3, 6, 8)):
    """Split a flat action sequence into one chunk per action branch.

    With the default boundaries (3, 6, 8) the result is
    ``[action[:3], action[3:6], action[6:8], action[8:]]``. That is one chunk
    per branch of the [3, 3, 2, 3] discrete action space the environment
    reports. Pass different ``split_points`` for a differently shaped space;
    the final chunk always takes whatever remains.

    Returns a list of ``len(split_points) + 1`` slices of ``action``.
    """
    bounds = [0, *split_points, None]
    return [action[lo:hi] for lo, hi in zip(bounds, bounds[1:])]

In [3]:
# Training environment: not rendered in realtime, so episodes run as fast as possible.
# With retro=False the observations are tuples that the training loop indexes by position.
env = ObstacleTowerEnv(
    './ObstacleTower/obstacletower',
    retro=False,
    realtime_mode=False,
)

INFO:mlagents_envs:
'ObstacleTower-v2.2' started successfully!
Unity Academy name: ObstacleTower-v2.2
        Number of Brains: 1
        Number of Training Brains : 1
        Reset Parameters :
		starting-floor -> 0.0
		visual-theme -> 1.0
		allowed-rooms -> 2.0
		default-theme -> 0.0
		allowed-floors -> 2.0
		agent-perspective -> 1.0
		lighting-type -> 1.0
		dense-reward -> 1.0
		allowed-modules -> 2.0
		tower-seed -> -1.0
		total-floors -> 100.0
Unity brain name: LearningBrain
        Number of Visual Observations (per agent): 1
        Vector Observation space size (per agent): 8
        Number of stacked Vector Observation: 1
        Vector Action space type: discrete
        Vector Action space size (per agent): [3, 3, 2, 3]
        Vector Action descriptions: Movement Forward/Back, Camera, Jump, Movement Left/Right
INFO:gym_unity:1 agents within environment.


In [4]:
# Per-branch sizes of the discrete action space (the log above reports [3, 3, 2, 3])
# and the shape of the visual observation.
action_size = env.action_space.nvec.tolist()
state_size = list(env.observation_space[0].shape)

# Replay-buffer / agent hyperparameters. The training loop also writes to `buffer` directly.
BUFFER_SIZE, BATCH_SIZE, N_ITER = 2e4, 128, 8

buffer = Buffer(buffer_size=BUFFER_SIZE, batch_size=BATCH_SIZE)
agent = DQN_Agent(
    state_size=state_size,
    action_size=action_size,
    buffer=buffer,
    device=device,
    n_iter=N_ITER,
)

In [None]:
DISPLAY_EVERY = 10
EPOCHS = 100
# Reward stored for frames that never led to a new floor. Also the starting
# value of each episode's accumulated reward.
NO_PROGRESS_REWARD = -0.1

rewards = deque(maxlen=DISPLAY_EVERY)  # accumulated reward of the last DISPLAY_EVERY episodes


def _flush_transitions(transitions, reward):
    """Add every (state, action, next_state, done) tuple to the replay `buffer` with `reward`.

    Uses its own loop names, so the caller's live `state`/`action`/
    `next_state`/`done` are never rebound by the dump.
    """
    for s, a, s_next, d in transitions:
        buffer.add(state=s, action=a, reward=reward, next_state=s_next, done=d)


for epoch in range(EPOCHS):
    # Reset observation is indexed positionally: [0] observation, [2] time remaining, [3] floor.
    env_info = env.reset()
    state = ProcessState(env_info[0])
    curr_level = env_info[3]
    curr_time = env_info[2]
    local_buffer = []  # transitions since the last floor change; their reward is not known yet
    time, done = curr_time, False
    acc_reward = NO_PROGRESS_REWARD

    # Play one episode until it ends or the clock runs out
    agent.step_begin()
    while not done and time > 0:
        action = agent.act(state=state)
        env_info = env.step(action[0])
        info = env_info[-1]
        next_state = ProcessState(env_info[0][0])
        level, time, done = info['current_floor'], info['time_remaining'], env_info[2]
        local_buffer.append((state, action, next_state, done))

        # On reaching a new floor, compute the reward and credit it to every
        # transition since the previous floor change. The reward is 1 + floor
        # number, minus the (non-negative) fraction of the time that was left at
        # the previous floor change which has since been used up.
        if level > curr_level:
            reward = 1 + level - max(0, (curr_time - max(0, time)) / curr_time)
            curr_level, curr_time = level, max(0, time)
            acc_reward += reward
            _flush_transitions(local_buffer, reward)
            local_buffer = []

        if done:
            break
        state = next_state

    rewards.append(acc_reward)

    # Save the frames that never earned a floor reward to the global buffer
    _flush_transitions(local_buffer, NO_PROGRESS_REWARD)
    del local_buffer

    # Train agent
    agent.step_update()

    agent.save(prefix="Test")
    print("[{}/{}] {:0.4f}         ".format(epoch+1, EPOCHS, acc_reward), end="\r")

    if (epoch+1) % DISPLAY_EVERY == 0:
        print("[{}/{}] Mean: {:0.4f}         ".format(epoch+1, EPOCHS, np.mean(rewards)))

    agent.step_end()

[5/100] -0.1000         

In [None]:
# Shut down the Unity environment, then release PyTorch's cached but unused GPU memory.
env.close()
torch.cuda.empty_cache()

In [None]:
# NOTE(review): this duplicates the `env.close()` in the previous cell. Closing an
# environment that is already closed may raise (mlagents_envs reports that no
# Unity environment is loaded), which would abort Restart & Run All. So this
# second close is best-effort: it reports the problem instead of crashing.
try:
    env.close()
except Exception as exc:
    print("env.close() skipped: {}".format(exc))

In [None]:
# Show an iteration on screen: reopen the environment in realtime mode so the
# rollout can be watched, then play one episode with the trained agent.
if not env.realtime_mode:
    try:
        env.close()
    except Exception as exc:  # the cleanup cells above may already have closed it
        print("env.close() skipped: {}".format(exc))
    env = ObstacleTowerEnv('./ObstacleTower/obstacletower', retro=False, realtime_mode=True)

# Start from a fresh episode of *this* environment. The `state`/`curr_time`
# left over from training belong to the old environment.
env_info = env.reset()
state = ProcessState(env_info[0])
time, done = env_info[2], False
# NOTE(review): if agent.act explores (e.g. epsilon-greedy), this is not a purely greedy rollout -- confirm.
while not done and time > 0:
    action = agent.act(state=state)
    env_info = env.step(action[0])
    info = env_info[-1]
    state, level, time, done = ProcessState(env_info[0][0]), info['current_floor'], info['time_remaining'], env_info[2]