In [None]:
import gym, importlib, os, sys, warnings, IPython
import tensorflow as tf
import itertools
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

%autosave 240
print(tf.__version__)

sys.path.append('../../embodied_arch/')
import embodied_AC as em
from embodied_misc import ActionPolicyNetwork, ValueNetwork, SensoriumNetworkTemplate
importlib.reload(em)


## suppress annoy verbose tf msgs
warnings.filterwarnings("ignore")
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # '3' to block all including error msgs
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

## Cartpole Benchmark Setup

In [None]:
# Network factories handed to the agent constructor. Each one only forwards
# its input(s) to the corresponding embodied_misc class with fixed
# hyper-parameters (hSeq is presumably the hidden-layer widths and gamma_reg a
# regularisation weight -- confirm against embodied_misc).

def actor(s):
    """Return ActionPolicyNetwork(s) configured with hSeq=(32,), gamma_reg=1e-1."""
    return ActionPolicyNetwork(s, hSeq=(32,), gamma_reg=1e-1)


def value(s):
    """Return ValueNetwork(s) configured with hSeq=(16,16,8,), gamma_reg=1."""
    return ValueNetwork(s, hSeq=(16,16,8,), gamma_reg=1.)


def sensor(st, out_dim):
    """Return SensoriumNetworkTemplate(st) with hSeq=(16,), the given out_dim, gamma_reg=5."""
    return SensoriumNetworkTemplate(st, hSeq=(16,), out_dim=out_dim, gamma_reg=5.)

In [None]:
# Start from a clean default graph so re-running this cell builds the agent
# from scratch. The tf.compat.v1 alias is used (as in the logging setup of the
# first cell) because the bare `tf.reset_default_graph` name exists only in
# TensorFlow 1.x.
tf.compat.v1.reset_default_graph()
importlib.reload(em)  # re-import embodied_AC so code edits take effect
env = gym.make('CartPole-v0')

# Actor-critic agent wired with the actor / sensor / value factories defined
# in the previous cell.
# NOTE(review): space_size=(4,1) presumably means (state size, action size) --
# confirm against EmbodiedAgentAC.
cpac = em.EmbodiedAgentAC(
    name="cp-emb-ac", env_=env,
    space_size=(4, 1), latentDim=4,
    alpha_p=1., alpha_v=1e-2,
    actorNN=actor, sensorium=sensor, valueNN=value
)

print(cpac, cpac.s_size, cpac.a_size)

In [None]:
# Checkpoint writer (keeps only the most recent checkpoint) and the session
# shared by every later cell. Both are reached through tf.compat.v1, matching
# the logging setup in the first cell; the bare tf.train.Saver /
# tf.InteractiveSession names exist only in TensorFlow 1.x.
saver = tf.compat.v1.train.Saver(max_to_keep=1)
sess = tf.compat.v1.InteractiveSession()
cpac.init_graph(sess)

# Experiment sizes used by the baseline / test loops and by training.
num_episodes = 100
n_epochs = 601  # alternative setting left by the author: n_epochs = 1000

In [None]:
## Verify step + play set up
# Smoke test of the agent/environment wiring: reset the environment, ask the
# agent for an action on the initial observation and print both.
state = cpac.env.reset()
print(state, cpac.act(state, sess))
# cpac.env.step(cpac.act(state, sess))

# Run play() once with default arguments (presumably a single episode --
# confirm in embodied_AC).
cpac.play(sess)
# Last expression of the cell, so the notebook displays its return value.
cpac.episode_length()

## Baseline

In [None]:
# Collect the return of `num_episodes` play() calls *before* any training so
# that the trained agent can be compared against it later (`uplen0`,
# `base_perf` are reused by the Test / Evaluate cells).
print('Baselining untrained pnet...')
uplen0 = []
for k in range(num_episodes):
    cpac.play(sess, terminal_reward=0.)
    uplen0 += [cpac.last_total_return]
    # Lightweight progress indicator, rewritten in place every 20 episodes.
    if not k % 20:
        print("\rEpisode {}/{}".format(k, num_episodes), end="")
base_perf = np.mean(uplen0)
print("\nCartpole stays up for an average of {} steps".format(base_perf))

In [None]:
# Reset the environment after the baseline run and keep the returned initial
# observation in `st` (not read by any later cell visible in this notebook).
st = cpac.env.reset()

## Train

In [None]:
# Value pre-training: alternate play() and pretrainV() for a fixed number of
# iterations and record whatever pretrainV() returns (reported below as the
# value loss). `obs` is plotted here and again in the next cell.
obs = []
for iteration in range(1250):
    cpac.play(sess)
    value_loss = cpac.pretrainV(sess)
    obs += [value_loss]
    print('\r\tIteration {}: Value loss({})'.format(iteration, value_loss), end="")
plt.plot(obs)

In [None]:
# Distribution of the recorded value losses. The data is passed as `x=`
# explicitly: the first positional argument of violinplot means `x` in older
# seaborn releases but `data` in newer ones, which changes the orientation.
sns.violinplot(x=obs)

In [None]:
# Train pnet on cartpole episodes
print('Training...')
# Fresh checkpoint writer for the training run (keeps only the latest
# checkpoint). Reached through tf.compat.v1, matching the logging setup in the
# first cell; the bare tf.train.Saver name exists only in TensorFlow 1.x.
saver = tf.compat.v1.train.Saver(max_to_keep=1)
# `hist` holds whatever work() returns; it is plotted in the next cell.
hist = cpac.work(sess, saver, num_epochs=n_epochs)

In [None]:
# Distribution of the training history returned by work(). Passed as `x=`
# explicitly: the first positional argument of violinplot means `x` in older
# seaborn releases but `data` in newer ones.
# NOTE(review): assumes `hist` is a flat sequence of numbers -- confirm.
sns.violinplot(x=hist)

## Test

In [None]:
# Test pnet!
# Same measurement as the baseline cell, now with the trained agent; the
# per-episode returns go to `uplen` and their mean to `trained_perf`.
print('Testing...')
uplen = []
for k in range(num_episodes):
    cpac.play(sess, terminal_reward=0.)
    uplen += [cpac.last_total_return]
    # Lightweight progress indicator, rewritten in place every 20 episodes.
    if not k % 20:
        print("\rEpisode {}/{}".format(k, num_episodes), end="")
trained_perf = np.mean(uplen)
print("\nCartpole stays up for an average of {} steps compared to baseline {} steps".format(trained_perf, base_perf) )

## Evaluate

In [None]:
# Baseline vs. trained episode returns, stacked on a shared x axis so the two
# distributions can be compared directly.
# The samples are passed as `x=` explicitly: the first positional argument of
# violinplot means `x` in older seaborn releases but `data` in newer ones,
# which would put the values on the y axis and defeat sharex=True.
fig, axs = plt.subplots(2, 1, sharex=True)
sns.violinplot(x=uplen0, ax=axs[0])
axs[0].set_title('Baseline Episode Lengths')
sns.violinplot(x=uplen, ax=axs[1])
axs[1].set_title('Trained Episode Lengths')

In [None]:
# Disabled cell: manual roll-out of one episode (capped at 1000 steps) that
# renders every step and accumulates the transitions in `buf`. Uncomment all
# lines below to run it.
# buf = []
# last_total_return, d, s = 0, False, cpac.env.reset() 
# while (len(buf) < 1000) and not d:
#     a_t = cpac.act(s, sess) 
#     s1, r, d, *rest = cpac.env.step(a_t)
#     cpac.env.render()
#     buf.append([s, a_t, float(r), s1])
#     last_total_return += float(r)
#     s = s1
#     print("\r\tEpisode Length", len(buf), end="")

In [None]:
# Disabled: uncomment to close the session created in the session-setup cell
# once you are finished with the notebook.
# sess.close()