In [1]:
import dataclasses
import functools

import jax
from jax import numpy as jnp
import numpy as np

from flax import linen
from flax import struct
import optax
from tensorflow_probability.substrates import jax as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

import gym

import daves_rl_lib
from daves_rl_lib import networks
from daves_rl_lib.environments import environment_lib
from daves_rl_lib.algorithms import deep_q_network

In [5]:
import pyvirtualdisplay
import cv2

# Off-screen virtual display. Per the original note, `visible=False` is the
# setting to use with Xvfb. Presumably this is what lets the later render
# calls work without a physical screen -- TODO confirm.
_display = pyvirtualdisplay.Display(size=(600, 400), visible=False)
_ = _display.start()  # bind the return value so the cell shows no output

In [2]:
# Hyperparameters consumed by the setup cell that follows.
buffer_size = 64  # forwarded as `buffer_size` to `deep_q_network.initialize_learner`
epsilon = 0.1  # forwarded as `epsilon` to the compiled update step (presumably the epsilon-greedy exploration rate -- confirm)
discount_factor = 0.6  # forwarded as `discount_factor` to `environment_lib.GymEnvironment`
learning_rate = 0.1  # learning rate handed to `optax.adam`

In [3]:
# Wrap the Gym CartPole task in the library's environment type.
cartpole = gym.make("CartPole-v1")
env = environment_lib.GymEnvironment(cartpole, discount_factor=discount_factor)

# Q-value model: the first argument is a one-element list holding the number
# of actions; the observation size is taken from the environment.
qvalue_net = networks.make_model(
    [env.action_space.num_actions],
    obs_size=env.observation_size,
)
qvalue_optimizer = optax.adam(learning_rate)

# Keyword arguments that both the learner constructor and the update-step
# compiler receive.
shared_kwargs = dict(
    env=env,
    qvalue_net=qvalue_net,
    qvalue_optimizer=qvalue_optimizer,
)

learner = deep_q_network.initialize_learner(
    buffer_size=buffer_size,
    batch_size=None,
    seed=jax.random.PRNGKey(0),
    **shared_kwargs)

step_learner = deep_q_network.compile_deep_q_update_step_stateful(
    gradient_batch_size=8,
    target_weights_decay=0.9,
    epsilon=epsilon,
    **shared_kwargs)

# Observation held by the freshly initialized learner; a later cell feeds it
# back through the Q-network to inspect the learned values.
initial_obs = learner.agent_states.observation

In [6]:
# Advance the learner a fixed number of steps, recording the action, done flag
# and observation after each one, plus a rendered frame for the video cells.
num_steps = 256
actions, done, observations, images = [], [], [], []
for step in range(num_steps):
    learner = step_learner(learner)
    agent_states = learner.agent_states
    actions.append(learner.last_action)
    done.append(agent_states.done)
    observations.append(agent_states.observation)
    # Reaches into the wrapper's private Gym handle to grab an RGB frame.
    images.append(env._gym_env.render(mode="rgb_array"))

In [7]:
# Q-network output for the observation captured before training started.
initial_qvalues = qvalue_net.apply(learner.qvalue_weights, initial_obs)
print("Estimated value", initial_qvalues)

# Average of the rewards stored on the valid transitions in the replay buffer.
buffered_rewards = learner.replay_buffer.valid_transitions().next_state.reward
print("Mean reward", np.mean(buffered_rewards))

Estimated value [2.2625456 2.5740664]
Mean reward 1.0


In [8]:
# The leading axes of a frame are treated as (height, width); the last axis is dropped.
height, width = images[0].shape[:-1]

# XVID-encoded AVI at 15 frames per second. Note the (width, height) order of frameSize.
fourcc = cv2.VideoWriter_fourcc(*"XVID")
out = cv2.VideoWriter("/tmp/cartpole2.avi", fourcc, 15, frameSize=(width, height))

In [9]:
# Encode every captured frame into the AVI, then finalize the file.
# The frames were rendered with mode="rgb_array" (RGB channel order), while
# cv2.VideoWriter.write expects BGR; convert first, otherwise red and blue
# come out swapped in the resulting video.
for frame in images:
    out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release()

In [17]:
from IPython import display

In [20]:
# Try to show the raw AVI inline.
# NOTE(review): browsers frequently cannot play XVID/AVI in a <video> element;
# the cells below re-encode to H.264 and embed that instead -- confirm whether
# this cell is still needed.
display.Video(filename='/tmp/cartpole2.avi')

In [23]:
import os
import subprocess

save_path = '/tmp/cartpole2.avi'
compressed_path = "/tmp/result_compressed.mp4"

# Re-encode the AVI as H.264 so it can be embedded as MP4.
#  * Passing an argument list (no shell) means the paths are never parsed by
#    a shell, so spaces or metacharacters in them are harmless.
#  * `-y` overwrites an existing output file; without it ffmpeg stops to ask
#    for confirmation (or refuses to overwrite) when this cell is re-run.
#  * `check=True` raises CalledProcessError if ffmpeg fails, instead of
#    silently yielding a non-zero status that nothing inspects.
# The trailing `.returncode` keeps the cell's displayed value (0 on success).
subprocess.run(
    ["ffmpeg", "-y", "-i", save_path, "-vcodec", "libx264", compressed_path],
    check=True,
).returncode

ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers
  built with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)
  configuration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --e

0

In [24]:
from IPython.display import HTML
from base64 import b64encode
 
def show_video(video_path, video_width = 600):
  """Return an inline HTML5 player for the video file at `video_path`.

  The file's bytes are base64-encoded into a `data:video/mp4` URL, so the
  returned markup carries the whole video and needs no external file.

  Args:
    video_path: Path of the video file to embed.
    video_width: Value placed in the `<video>` tag's `width` attribute.

  Returns:
    The result of wrapping the `<video>` markup in `HTML`.

  Raises:
    OSError: If the file cannot be opened for reading.
  """
  # Open read-only ("rb", not "r+b") so write permission is not required, and
  # use a context manager so the handle is closed as soon as the bytes are read.
  with open(video_path, "rb") as video_file:
    video_bytes = video_file.read()

  video_url = f"data:video/mp4;base64,{b64encode(video_bytes).decode()}"
  return HTML(f"""<video width={video_width} controls><source src="{video_url}"></video>""")
 
# Embed the H.264 re-encode written by the ffmpeg cell above.
show_video(compressed_path)