In [1]:
import pickle
import time

import numpy as np
import tensorflow as tf
import gym
import gym.wrappers as gw

from tf_reinforcement_testcases import models

In [2]:
# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all
# up front. The flag is applied to every visible GPU: TensorFlow requires
# memory growth to be the same on all GPUs, so enabling it on the first device
# only breaks on multi-GPU machines.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for device in physical_devices:
    # set_memory_growth returns None, so its result is not stored.
    tf.config.experimental.set_memory_growth(device, True)

In [3]:
# Read the saved checkpoint from disk. A missing file is tolerated and leaves
# init_data set to None.
# NOTE(review): unpickling executes code embedded in the file; only load
# checkpoints from a trusted source.
try:
    with open('data/data.pickle', 'rb') as file:
        init_data = pickle.loads(file.read())
except FileNotFoundError:
    init_data = None

In [4]:
# Gym ids of the supported environments; env_name selects the one to run.
cart_pole, breakout = 'CartPole-v1', 'BreakoutNoFrameskip-v4'
env_name = breakout

In [5]:
# Create the environment. Breakout additionally gets, from innermost to
# outermost wrapper: Atari preprocessing, a 10000-step episode cap, and a
# frame stack of 4.
env = gym.make(env_name)
if env_name == 'BreakoutNoFrameskip-v4':
    env = gw.AtariPreprocessing(env)
    env = gw.TimeLimit(env, max_episode_steps=10000)
    env = gw.FrameStack(env, 4)

# Dimensions used to build the network: observation shape in, one output
# per discrete action.
input_shape = env.observation_space.shape
n_outputs = env.action_space.n

In [6]:
# Build the network that matches the selected environment: a convolutional
# model for Breakout, an MLP otherwise.
if env_name == 'BreakoutNoFrameskip-v4':
    model = models.get_conv_channels_first(input_shape, n_outputs)
else:
    model = models.get_mlp(input_shape, n_outputs)

# init_data is None when the checkpoint file was not found; subscripting it
# would raise an opaque TypeError, so that case is reported explicitly and
# the model keeps the weights it was built with.
if init_data is None:
    print("Warning: no checkpoint loaded (init_data is None); "
          "the model keeps the weights it was built with.")
else:
    model.set_weights(init_data['weights'])

In [7]:
def predict(obs):
    """Pick the greedy action for a single observation.

    Every leaf of ``obs`` gets a leading axis of size 1 before being passed
    to the global ``model``; the result is the argmax over the first row of
    the model output (treated here as Q-values).

    Args:
        obs: a single observation, possibly a nested structure.

    Returns:
        The index of the largest entry in the first row of the model output.
    """
    batched_obs = tf.nest.map_structure(
        lambda leaf: tf.expand_dims(leaf, axis=0),
        obs,
    )
    q_values = model(batched_obs)
    return np.argmax(q_values[0])

In [8]:
# Play one rendered episode, choosing every action with predict(), and print
# the step count and accumulated reward once the environment reports done.
for i_episode in range(1):
    observation = env.reset()
    rewards = 0
    for t in range(10000):
        # Brief pause before each frame, presumably to keep the rendering
        # at a watchable speed.
        time.sleep(0.01)
        env.render()
        action = predict(observation)
        observation, reward, done, info = env.step(action)
        rewards += reward
        if not done:
            continue
        print(f"Episode finished after {t+1} timesteps; Total reward: {rewards}")
        break

Episode finished after 757 timesteps; Total reward: 36.0


In [9]:
# Shut the environment down now that the rollout is finished.
env.close()