In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import os
import imageio

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)

# To get smooth animations
import matplotlib.animation as animation
mpl.rc('animation', html='jshtml')

In [2]:
from tf_agents.environments import suite_atari
from tf_agents.environments.atari_preprocessing import AtariPreprocessing
from tf_agents.environments.atari_wrappers import FrameStack4
from tf_agents.environments.tf_py_environment import TFPyEnvironment

environment_name = "BreakoutNoFrameskip-v4"
# Episode step limit, presumably the one used during training (27k agent steps).
# NOTE(review): the visualization env below passes max_episode_steps=None, so
# this constant is not used by the cells shown here.
max_episode_steps = 27_000

class AtariPreprocessingWithAutoFire(AtariPreprocessing):
    """`AtariPreprocessing` that presses FIRE automatically when Breakout needs it.

    Breakout only launches the ball after the FIRE action, both at the start of
    an episode and after every lost life. Pressing it automatically keeps the
    agent from stalling with no ball in play.
    """

    FIRE_ACTION = 1  # action index used as FIRE ("FIRE to start")

    def reset(self, **kwargs):
        """Reset the game and press FIRE once to launch the first ball.

        Returns the observation from the underlying reset, i.e. the frame from
        *before* the automatic FIRE. NOTE(review): the post-FIRE observation is
        discarded; left unchanged, presumably so observations match what
        existing policies were trained on. Reconsider for new training runs.
        """
        obs = super().reset(**kwargs)
        super().step(self.FIRE_ACTION)
        return obs

    def step(self, action):
        """Apply `action`; if it cost a life and the game is not over, press FIRE.

        The reward and termination flag of the automatic FIRE step are merged
        into the returned values. No reward is silently lost, and a game that
        ends during the FIRE step is reported as done instead of continuing on
        a finished game. The returned observation and info are those of the
        agent's own action.
        """
        lives_before_action = self.ale.lives()
        obs, rewards, done, info = super().step(action)
        if self.ale.lives() < lives_before_action and not done:
            _, fire_rewards, fire_done, _ = super().step(self.FIRE_ACTION)
            rewards += fire_rewards
            done = fire_done  # `done` was False to enter this branch
        return obs, rewards, done, info

In [3]:
# Environment used only to watch the saved policy play. It uses the same gym
# wrappers as defined above (auto-FIRE preprocessing, then 4-frame stacking).
# max_episode_steps=None means no step limit: the watched episode runs until the
# game itself ends.
env_vis = suite_atari.load(
    environment_name,
    gym_env_wrappers=[AtariPreprocessingWithAutoFire, FrameStack4],
    max_episode_steps=None,
)
tf_env_vis = TFPyEnvironment(env_vis)  # TensorFlow view of env_vis, used by the driver

A.L.E: Arcade Learning Environment (version +978d2ce)
[Powered by Stella]


In [20]:
# Load the exported DDQN policy (the directory name suggests the checkpoint
# taken after 500,000 training steps).
saved_policy = tf.saved_model.load(export_dir="modelos/ddqn_1/ddqn_500000")

In [21]:
def update_scene(num, frames, patch):
    """Animation callback: show image number `num` of `frames` in `patch`.

    Returns a one-element tuple with the updated artist, as `FuncAnimation`
    callbacks are expected to return the artists they modified.
    """
    current_frame = frames[num]
    patch.set_data(current_frame)
    return (patch,)

def plot_animation(frames, repeat=False, interval=40):
    """Build a matplotlib animation that plays `frames` in order.

    Args:
        frames: non-empty sequence of images accepted by `imshow` (here, the
            RGB arrays from `render(mode="rgb_array")`).
        repeat: whether the animation loops after the last frame.
        interval: delay between frames in milliseconds (40 ms = 25 fps).

    Returns:
        The `FuncAnimation`. As the last expression of a cell it is displayed
        with the `jshtml` HTML writer configured at the top of the notebook.

    Raises:
        ValueError: if `frames` is empty (instead of an opaque IndexError).
    """
    if len(frames) == 0:
        raise ValueError("plot_animation() needs at least one frame")
    fig, ax = plt.subplots()
    patch = ax.imshow(frames[0])
    ax.axis('off')
    anim = animation.FuncAnimation(
        fig, update_scene, fargs=(frames, patch),
        frames=len(frames), repeat=repeat, interval=interval)
    # Close this figure explicitly (not just "the current one") so the cell
    # shows only the animation, not an extra still image of the first frame.
    plt.close(fig)
    return anim

# Screens of the watched episode, filled in by the `save_frames` observer.
frames = []


def save_frames(trajectory):
    """Driver observer: append the current screen of the visualization env to `frames`.

    `trajectory` is ignored; it is accepted only because the driver calls each
    observer with one argument per step. No `global` statement is needed
    because the list is mutated in place, never rebound.
    """
    screen = tf_env_vis.pyenv.envs[0].render(mode="rgb_array")
    frames.append(screen)
    
from tf_agents.drivers.dynamic_episode_driver import DynamicEpisodeDriver

# Let the saved policy play one episode in the visualization environment.
# The driver calls `save_frames` as an observer, which records the screen.
watch_driver = DynamicEpisodeDriver(
    tf_env_vis, saved_policy, observers=[save_frames], num_episodes=1
)
final_time_step, final_policy_state = watch_driver.run()

# Last expression of the cell: rendered inline as an animation.
plot_animation(frames)

In [6]:
# Save the recorded frames as a video file (imageio picks the format from the
# file extension, e.g. ".gif" or ".mp4").

def guardar_video(filename, frames, fps=60):
    """Write `frames`, in order, to the video file `filename` ("guardar video" = "save video").

    Args:
        filename: output path. Its parent directory is created if missing.
        frames: iterable of images (here, the RGB arrays collected by
            `save_frames`).
        fps: playback frame rate of the file. Defaults to 60, the previously
            hard-coded value. The inline animation uses `interval=40` ms
            (25 fps), so pass `fps=25` to make the file play at the same speed.
    """
    out_dir = os.path.dirname(filename)
    if out_dir:
        # Create e.g. "videos/" up front; the writer may fail if it is missing.
        os.makedirs(out_dir, exist_ok=True)
    with imageio.get_writer(filename, fps=fps) as video:
        for frame in frames:
            video.append_data(frame)
    


In [17]:
# Export the frames recorded in the watch-driver cell as an MP4 file.
guardar_video(filename="videos/ddqn_500000.mp4", frames=frames)

