In [1]:
# Activate matplotlib's `notebook` backend for this kernel session.
%matplotlib notebook

In [2]:
import os

# Move the working directory up one level, presumably so that the relative
# paths used later in this notebook (the `utils` package, 'models/...')
# resolve against the parent directory.
# os.chdir('..') is relative, so re-running this cell in a live kernel would
# climb one more level every time. The flag makes the cell idempotent: it lives
# in the kernel namespace and is reset, together with the cwd, on restart.
if not globals().get('_CWD_MOVED_TO_PARENT', False):
    os.chdir('..')
    _CWD_MOVED_TO_PARENT = True

In [3]:
import pyvirtualdisplay
import imageio 
import base64
import IPython


from acme import EnvironmentLoop
from acme.tf import networks
from acme.wrappers import gym_wrapper
from acme import specs
from acme.agents.tf import dqn
from acme.utils.loggers.tf_summary import TFSummaryLogger
import trfl

import sonnet as snt
import tensorflow as tf

from utils import sonnet_resnet

# Enable on-demand GPU memory growth for TensorFlow.
# Every listed GPU is configured rather than only index 0: TensorFlow expects
# the memory-growth setting to agree across GPUs, and iterating also turns a
# CPU-only machine (empty list) into a no-op instead of an IndexError.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, enable=True)

# Set up a virtual display for rendering OpenAI gym environments.
# NOTE(review): the name `display` shadows IPython's built-in display() helper
# in the notebook namespace; the name is kept because other cells may use it.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()

In [4]:
# List the GPU devices TensorFlow reports; as the cell's last expression the
# result is echoed as the cell output.
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [5]:
from utils.gym_env import MinerEnv

def display_video(frames, filename='temp.mp4', fps=60):
  """Encode frames into a video file and return it as an inline HTML player.

  Args:
    frames: iterable of frames; each one is handed to the imageio writer's
      `append_data`.
    filename: path of the video file that is written and then read back.
    fps: frame rate passed to `imageio.get_writer`. Defaults to 60, the value
      that used to be hard-coded.

  Returns:
    An `IPython.display.HTML` object holding a `<video>` element whose `src`
    is a data URI with the base64-encoded contents of `filename`.
  """
  # Write video
  with imageio.get_writer(filename, fps=fps) as video_writer:
    for frame in frames:
      video_writer.append_data(frame)
  # Read the encoded file back; the context manager closes the handle instead
  # of leaving it to the garbage collector.
  with open(filename, 'rb') as video_file:
    video_bytes = video_file.read()
  b64_video = base64.b64encode(video_bytes).decode()
  # The <video> element needs an explicit end tag, otherwise markup rendered
  # after it can be parsed as the element's fallback content.
  video_tag = ('<video  width="320" height="240" controls alt="test" '
               'src="data:video/mp4;base64,{0}"></video>').format(b64_video)
  return IPython.display.HTML(video_tag)

In [6]:
# Wrap MinerEnv with Acme's gym wrapper and derive the environment spec from it.
env = gym_wrapper.GymWrapper(MinerEnv())
environment_spec = specs.make_environment_spec(env)
num_actions = environment_spec.actions.num_values

# Q-network: a ResNet torso (num_output_hidden set to twice the action count)
# followed by a LayerNormMLP with one output per action.
torso = sonnet_resnet.ResNetTorso(num_output_hidden=(num_actions * 2,))
q_head = networks.LayerNormMLP(layer_sizes=(num_actions,))
base_net = snt.Sequential([torso, q_head])

# Non-trainable variables so later cells can change exploration and learning
# rate via .assign() without rebuilding the networks.
epsilon = tf.Variable(1., trainable=False)
rl = tf.Variable(0.00025, trainable=False)


def _sample_epsilon_greedy(q_values):
    """Sample an action from trfl's epsilon-greedy distribution over q_values."""
    return trfl.epsilon_greedy(q_values, epsilon=epsilon).sample()


def _as_int64(action):
    """Cast the sampled action to int64."""
    return tf.cast(action, tf.int64)


# Behaviour policy: Q-values -> epsilon-greedy sample -> int64 action.
policy_modules = [base_net, _sample_epsilon_greedy, _as_int64]

policy_network = snt.Sequential(policy_modules)

In [7]:
%%capture
# Single directory for this run: used as the checkpoint sub-path, with the
# TF summary logs written to its 'logs' sub-directory.
model_dir = 'models/11_rl_001_e_3_nstep_4_bs10k_he'
summary_logger = TFSummaryLogger(model_dir + '/logs')

# Not passed (library defaults apply); earlier experiments, kept for reference:
#   prefetch_size=8000, samples_per_insert=8000, min_replay_size=5,
#   max_replay_size=1000000, epsilon=0.3
agent = dqn.DQN(
    # Environment and networks.
    environment_spec=environment_spec,
    network=base_net,
    policy_network=policy_network,
    # Learning hyper-parameters; `rl` is a tf.Variable reassigned by the
    # training cell.
    batch_size=10000,
    target_update_period=10,
    n_step=4,
    learning_rate=rl,
    discount=0.99,
    # Prioritised-replay exponents.
    importance_sampling_exponent=0.2,
    priority_exponent=0.6,
    # Logging and checkpointing.
    logger=summary_logger,
    checkpoint=True,
    checkpoint_subpath=model_dir,
)





In [8]:
# Pair the training environment with the DQN agent; later cells call
# loop.run(...) on this object to run episodes.
loop = EnvironmentLoop(env, agent)

In [None]:
%%capture

# Training runs in `steps` phases. Before each phase, epsilon and the learning
# rate are set by linear interpolation from their start value toward their end
# value.
steps = 10
start_eps = 0.3
end_eps = 0.3

start_rl = 0.001
end_rl = 0.001

# Total amount each quantity would move across the schedule (zero as configured).
eps_drop = start_eps - end_eps
rl_drop = start_rl - end_rl
episodes_per_phase = 1000

for step in range(steps):
    # `step` runs 0..steps-1, so the final phase sits one increment short of the
    # end value; with start == end (as configured) both stay constant.
    epsilon.assign(start_eps - step * eps_drop / steps)
    rl.assign(start_rl - step * rl_drop / steps)

    loop.run(num_episodes=episodes_per_phase)

In [None]:
%%capture
# epsilon.assign(0.1)
# rl.assign(0.0001)
# loop.run(num_episodes=int(10000))

In [None]:
import numpy as np

# Roll out one episode in a fresh environment, recording a rendered frame
# before every action, then show the frames as an inline video.
test_env = gym_wrapper.GymWrapper(MinerEnv())
frames = []
num_steps = 100  # upper bound on the number of recorded frames / actions
timestep = test_env.reset()

# `num_steps` used to be assigned but never read, so the loop relied solely on
# the environment eventually reporting a terminal timestep. It now also acts as
# a cap, so a non-terminating episode cannot loop (and buffer frames) forever.
# NOTE(review): actions come from agent.select_action, which presumably uses
# the epsilon-greedy policy_network at the current epsilon -- confirm that an
# exploring policy is what this rollout is meant to show.
while not timestep.last() and len(frames) < num_steps:
    frames.append(test_env.environment.render(mode='rgb_array'))
    action = agent.select_action(timestep.observation)
    timestep = test_env.step(int(action))

display_video(np.array(frames))
