In [None]:
import gym
import numpy as np
import copy, random, time, subprocess, os
from tensorflow.keras import layers, models

In [None]:
class QValue:
  """Greedy action selection on top of a state-action value model.

  The wrapped ``model`` is called as ``model.predict([states, actions])``
  with one row per candidate action and must return one value per row.

  Args:
    model: Object exposing ``predict``. May be left as ``None`` and assigned
      later through the ``model`` attribute; ``get_action`` needs it set.
    n_actions: Number of discrete actions scored on each call. Defaults to 5.
  """

  def __init__(self, model=None, n_actions=5):
    self.model = model
    self.n_actions = n_actions

  def get_action(self, state):
    """Return the best-scoring action for ``state`` and its predicted value.

    Args:
      state: Array-like state; it is replicated once per candidate action.

    Returns:
      Tuple ``(action, q)`` where ``action`` is the argmax index over the
      model's predictions and ``q`` is the first element of that row.
    """
    n = self.n_actions
    # Batch of n identical states, paired row-by-row with the one-hot
    # encoding of each action (row a of the identity matrix encodes action a).
    states = np.repeat(np.asarray(state)[np.newaxis, ...], n, axis=0)
    actions = np.eye(n)

    q_values = self.model.predict([states, actions])
    # argmax flattens its input, so this assumes one value per action row.
    optimal_action = np.argmax(q_values)
    return optimal_action, q_values[optimal_action][0]

In [None]:
def join_frames(o0, o1):
    """Join two arrays by stacking their transposes and transposing back.

    Both inputs are transposed (axes reversed), concatenated along the first
    axis of the transposed arrays, and the result is transposed again. For
    arrays with the same number of dimensions this places ``o1`` after ``o0``
    along the last axis.
    """
    flipped_pair = (o0.transpose(), o1.transpose())
    stacked = np.concatenate(flipped_pair, axis=0)
    return stacked.transpose()

In [None]:
# Module-level Q-value wrapper. Its `model` attribute starts as None and is
# assigned inside create_gif() after a checkpoint file has been loaded.
q_value = QValue()

In [None]:
import datetime 
import imageio

def create_gif(checkpoint, model='model01', epsilon=0):
    """Replay one CarRacing-v2 episode with a saved checkpoint and write a GIF.

    The checkpoint file is copied from the model's GCS bucket with ``gsutil``,
    loaded into the module-level ``q_value`` wrapper, and the local copy is
    deleted. One episode is then played with an epsilon-greedy policy, every
    rendered frame is collected, and the frames are saved with imageio.

    Args:
        checkpoint: Checkpoint identifier used in the file name; must be
            convertible with ``int()`` for the output GIF name.
        model: Model name; selects both the bucket and the file name.
        epsilon: Probability of taking a uniformly random action at each step.

    Raises:
        subprocess.CalledProcessError: If the ``gsutil cp`` command fails.

    NOTE(review): the unpacking below assumes the gym API where ``reset()``
    returns only the observation and ``step()`` returns four values — confirm
    against the installed gym version.
    """
    bucket = 'gs://etsuji-car-racing-v2-{}'.format(model)
    filename = 'car-racing-v2-{}-{}.hd5'.format(model, checkpoint)
    # check=True: stop here with a clear error if the copy fails, instead of
    # failing later inside load_model on a missing file.
    subprocess.run(
        ['gsutil', 'cp', '{}/{}/{}'.format(bucket, model, filename), './'],
        check=True)
    print('load model {}'.format(filename))
    try:
        q_value.model = models.load_model(filename)
    finally:
        # Always drop the local copy, even when loading raised.
        if os.path.exists(filename):
            os.remove(filename)

    env = gym.make("CarRacing-v2", continuous=False)
    frames = []
    total_r = 0
    try:
        o0 = env.reset()
        # Start with two identical frames so the first joined input is valid.
        o1 = copy.deepcopy(o0)
        done = False
        c = 0

        while not done:
            if np.random.random() < epsilon:
                # Same action count as QValue.get_action scores by default.
                a = np.random.randint(5)
            else:
                a, _ = q_value.get_action(join_frames(o0, o1))
            o_new, r, done, _ = env.step(a)
            total_r += r
            # Slide the two-frame window forward by one observation.
            o0, o1 = o1, o_new
            c += 1
            frames.append(env.render('rgb_array'))
            if c % 30 == 0:
                # Progress trace: last action and running return.
                print('{}:{}'.format(a, int(total_r)), end=', ')
        print('{}:{}'.format(a, int(total_r)))
    finally:
        # Release the environment's resources whether or not the episode
        # finished cleanly.
        env.close()

    now = datetime.datetime.now()
    gif_name = 'car-racing-v2-{}-{:05d}-{}-{}.gif'.format(
        model, int(checkpoint), int(total_r), now.strftime('%Y%m%d-%H%M%S'))
    imageio.mimsave(gif_name, frames, 'GIF', duration=1.0/30.0)

In [None]:
# Which trained model to replay, and at which saved checkpoints.
MODEL_NAME = 'model04'
checkpoints = [432]

# One GIF per checkpoint; epsilon=0 means no random actions are taken.
for checkpoint in checkpoints:
    create_gif(checkpoint, model=MODEL_NAME, epsilon=0)