In [None]:
# Run selector: together these two values form the "<identifier>_exp_<resume>_"
# pattern that the setup cell searches for in the results directory.
identifier = "qbert_pi_prior_diff"  # experiment family name
resume = 1                          # experiment number (formatted with %04d)

In [None]:
%matplotlib notebook

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.animation as animation
import time
import socket
import matplotlib as mpl
import matplotlib.cm as cm
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

import sys
import os
import cv2

In [None]:

# Load the hyper parameters of a finished run from disk and rebuild sys.argv
# from them so the experiment is started in play (non-learning) mode.
# NOTE(review): the project modules imported at the bottom presumably parse
# sys.argv at import time, which is why argv is rewritten first -- confirm.
import getpass

# os.getlogin() requires a controlling terminal, which a notebook kernel often
# does not have; fall back to the environment-based lookup in that case.
try:
    username = os.getlogin()
except OSError:
    username = getpass.getuser()

if "gpu" in socket.gethostname():
    base_dir = os.path.join('/home/dsi/', username, 'data/rbi')
else:
    base_dir = os.path.join('/data/', username, 'rbi_atari')

# clear arguments (keep only the program name)
sys.argv = sys.argv[:1]

# Locate the results directory of the requested run. Fail loudly when it is
# missing instead of leaving exp_name undefined (or stale from an earlier run
# of this cell in the same kernel).
outdir = os.path.join(base_dir, 'results')
exp_prefix = "%s_exp_%04d_" % (identifier, resume)
exp_name = next((d for d in os.listdir(outdir) if exp_prefix in d), None)
if exp_name is None:
    raise FileNotFoundError(
        "no experiment directory containing %r found in %s" % (exp_prefix, outdir)
    )

# args.txt holds the original command line, one argument per line.
filename = os.path.join(outdir, exp_name, "args.txt")
with open(filename, 'r') as fp:
    args = fp.read().splitlines()

# add some default arguments
if "--learn" in args:
    args.remove("--learn")
if "--load-last-model" not in args:
    args.append("--load-last-model")

# Make sure the run number is passed on. The original code appended
# "--load-last-model" a second time here, so the resume value was never added.
# Appended last so it takes precedence over any earlier resume value.
# NOTE(review): assumes the project's parser accepts "--resume=<n>" -- confirm.
resume_arg = "--resume=%d" % resume
if resume_arg not in args and "--resume %d" % resume not in args:
    args.append(resume_arg)

sys.argv += args

from experiment import Experiment
from logger import logger

exp = Experiment(logger.filename)

# Seconds that get_frame() sleeps after drawing each frame.
sleep_time = 0.01
def set_delay(x):
    """Set the per-frame delay used by get_frame().

    Args:
        x: new delay in seconds; it is later passed to time.sleep().
    """
    # Without this declaration the assignment only creates a local variable
    # and the module-level delay never changes.
    global sleep_time
    sleep_time = x

def gen_function():
    """Yield the steps produced by exp.demonstrate(), each tagged with its index.

    Every step is expected to be a dict; its 'k' entry is set (overwriting any
    existing value) to the zero-based position of the step in the stream
    before the step is yielded.
    """
    position = 0
    for frame in exp.demonstrate(params=None):
        frame['k'] = position
        yield frame
        position += 1


def get_frame(data):
    """Refresh all four panels of the figure from one demonstration step.

    Passed to FuncAnimation as the per-frame callback. It relies on
    module-level objects created by the plotting cell: the axes ax1..ax4, the
    artists im_plot / value_plot / adv_plot / beta_plot, the history lists
    frame_index / v_data, the `actions` list and the `sleep_time` delay.

    Args:
        data: dict describing one step. Keys read here: 'k' (step index),
            's' (image shown in the top-left panel), 'v' (value appended to
            the curve), 'beta' and 'adv' (per-action bar heights), 'score',
            'a' (index into `actions`) and 'actions' (tick labels).
            NOTE(review): 'adv' must provide .argsort(), presumably a numpy
            array -- confirm against Experiment.demonstrate().
    """
    print("Get frame")

    # Read everything up front, before any artist is touched.
    step_no, screen, value = data['k'], data['s'], data['v']
    policy, total_score = data['beta'], data['score']
    action_id, advantage = data['a'], data['adv']

    # Titles: running score over the value curve, chosen action over the image.
    ax2.title.set_text('Score %d' % total_score)
    ax1.title.set_text('%s' % actions[action_id])

    im_plot.set_array(screen)

    # Extend the value history and keep at most the last 600 steps in view.
    frame_index.append(step_no)
    v_data.append(value)
    value_plot.set_data(frame_index, v_data)
    ax2.set_xbound(lower=max(0, step_no - 600), upper=step_no)
    ax2.set_ybound(lower=-1, upper=40)

    # Both bar charts get fixed y-limits and the same action tick labels.
    for axis, (low, high) in ((ax3, (-1, 2)), (ax4, (0, 1))):
        axis.set_ybound(lower=low, upper=high)
        axis.set_xticklabels(tuple(["NOOP"] + data['actions']))

    # Colour each 'beta' bar by the rank of its 'adv' entry (argsort applied
    # twice turns values into ranks for a numpy array) on the "hot" colormap.
    colour_of = cm.ScalarMappable(
        norm=mpl.colors.Normalize(vmin=0, vmax=len(advantage)),
        cmap=cm.hot,
    )
    rank = advantage.argsort().argsort()

    for idx, bar in enumerate(beta_plot):
        bar.set_height(policy[idx])
        bar.set_facecolor(colour_of.to_rgba(rank[idx]))

    for idx, bar in enumerate(adv_plot):
        bar.set_height(advantage[idx])

    # Throttle playback; the delay is the module-level sleep_time.
    time.sleep(sleep_time)
        
        
def onClick(event):
    """Toggle pause/resume of the running animation on a mouse click.

    Args:
        event: the matplotlib mouse event (unused).

    The module-level `pause` flag is flipped and then applied to the timer of
    the module-level animation `ani`, if one exists. Flipping the flag alone
    has no effect because nothing else reads it.
    """
    global pause
    pause ^= True

    # `ani` is created by a later cell; it may not exist yet, and its event
    # source can be gone once the animation has finished.
    source = getattr(globals().get('ani'), 'event_source', None)
    if source is None:
        return
    if pause:
        source.stop()
    else:
        source.start()


print("Ready for testing!")

In [None]:
# Build the 2x2 figure, seed every artist from the first demonstration step and
# start the animation that replays the remaining steps through get_frame().
pause = False

gen = gen_function()

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, sharex=False, sharey=False, figsize=(8, 8))

# The first step provides the initial content of every panel.
data = next(gen)
k = data['k'] #elad
s = data['s']
v = data['v']
a = data['a']
beta = data['beta']
score = data['score']
q = data['q']  # not used below; kept available as a notebook global
adv = data['adv']
actions = data['actions']

# Histories that get_frame() appends to on every step.
v_data = [v]
frame_index = [k]

# Artists that get_frame() updates in place.
im_plot = ax1.imshow(s, animated=True)
value_plot, = ax2.plot(frame_index, v_data, lw=2, color='r')
adv_plot = ax3.bar(np.arange(len(adv)), adv)
beta_plot = ax4.bar(np.arange(len(beta)), beta)

ax2.set_xbound(lower=0.0, upper=600)
ax2.title.set_text('Score %d' % score)

ax2.grid()
ax3.grid()
ax4.grid()

get_frame(next(gen))


def _frames_until_done(steps):
    """Yield every remaining step, then report that the simulation has ended."""
    yield from steps
    # Printed only once the step stream is exhausted. FuncAnimation returns
    # immediately, so a print placed after its construction would fire before
    # any frame has been played.
    print("END OF SIMULATION")


fig.canvas.mpl_connect('button_press_event', onClick)
# Keep a reference in `ani`: the animation stops if the object is garbage collected.
ani = animation.FuncAnimation(fig, get_frame, frames=_frames_until_done(gen), interval=10, repeat=False)