In [None]:
import torch
import torch.nn.functional as F
from torch.autograd import Variable

from a3c.envs import create_unreal_env
from a3c.model import ActorCritic
from pspnet.utils import color_class_image

import matplotlib.pyplot as plt
import numpy as np
from IPython import display

In [None]:
# Render matplotlib figures inside the notebook output area rather than in a
# separate window; the rollout cell relies on this to show each frame in place.
%matplotlib inline

In [None]:
# Set CUDA_VISIBLE_DEVICES for this kernel process so that only the GPU with
# index 1 is visible to CUDA. The variable is only honoured if it is set before
# CUDA is first initialised in the process.
# NOTE(review): the policy rollout in this notebook uses CPU tensors, so this
# presumably matters only for code behind create_unreal_env — confirm.
# (Keep this comment on its own line: %env would treat trailing text as part of
# the value.)
%env CUDA_VISIBLE_DEVICES=1

In [None]:
# Build the environment and restore the trained policy for evaluation.
# NOTE(review): the meaning of the `0` argument (rank? seed?) is defined by
# create_unreal_env — confirm there.
env = create_unreal_env(0)

# The network is sized from the first dimension of the observation space and
# from the environment's action space.
model = ActorCritic(env.observation_space.shape[0], env.action_space)
model.eval()  # evaluation mode: this notebook only runs inference

# Deserialize every storage onto the CPU regardless of the device the checkpoint
# was saved from. Without map_location, a checkpoint written from CUDA tensors
# cannot be loaded on a CPU-only machine or when the original device index is
# not visible. The callable form is used (rather than the string 'cpu') because
# it is also accepted by the older PyTorch releases this notebook targets.
model.load_state_dict(
    torch.load('checkpoints/best.pth',
               map_location=lambda storage, location: storage))

In [None]:
# Roll out one episode with the greedy policy, redrawing the figure after
# every environment step.

# Initial recurrent state passed to the model as (hx, cx). volatile=True is the
# pre-0.4 PyTorch way of asking for inference without building an autograd graph.
cx = Variable(torch.zeros(1, 256), volatile=True)
hx = Variable(torch.zeros(1, 256), volatile=True)

reward_sum, episode_length, done = 0, 0, False


def _colorize(observation):
    """Turn an observation into a colour image.

    Moves the first axis of `observation` to the end, takes the arg-max along
    it, and passes the resulting index map to `color_class_image`.
    (Presumably the first axis holds per-class scores — confirm.)
    """
    channels_last = observation.transpose(1, 2, 0)
    return color_class_image(np.argmax(channels_last, axis=2))


def _hide_axes(panels):
    """Switch off the axis decorations of every panel, in order."""
    for panel in panels:
        panel.axis('off')


def _refresh():
    """Display the current figure; the output is cleared when the next arrives."""
    display.display(plt.gcf())
    display.clear_output(wait=True)


# First frame: colourised observation on the left, rendered view on the right.
state = env.reset()
fig, axes = plt.subplots(1, 2, figsize=(15, 8))
segmented_image = _colorize(state)
segmented = axes[0].imshow(segmented_image)
original = axes[1].imshow(env.render())
_hide_axes(axes)
_refresh()

while not done:
    episode_length += 1

    # Re-wrap the raw tensors so the recurrent state carries no history from
    # the previous step.
    cx, hx = Variable(cx.data, volatile=True), Variable(hx.data, volatile=True)

    # The model takes one tuple: (batched observation, (hx, cx)).
    observation = Variable(torch.from_numpy(state).unsqueeze(0), volatile=True)
    value, logit, (hx, cx) = model((observation, (hx, cx)))

    # Greedy policy: pick the index of the most probable action.
    prob = F.softmax(logit, dim=1)
    best = prob.max(1, keepdim=True)[1]
    action = best.data.numpy()

    state, reward, done, info = env.step(action[0, 0])
    reward_sum += reward

    # Update the two existing images in place instead of creating new artists.
    segmented_image = _colorize(state)
    segmented.set_data(segmented_image)
    original.set_data(env.render())
    _hide_axes(axes)

    title = 'step {:3} | episode reward: {} | distance moved: {:3.2f} m'.format(
        episode_length, reward_sum, info['max_distance'])
    axes[0].set_title(title)
    _refresh()