In [3]:
import numpy as np
import time
import random

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from baselines.common.atari_wrappers import make_atari, wrap_deepmind



# Using no frameskip version, as it will be then made into frameskip (n=4) by the make_atari function
env = make_atari("SpaceInvadersNoFrameskip-v4")

# Reduce frame size and stacks 4 of them
env = wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=True, scale=True)
    # Clip_rewards: rewards are transformed in the [-1,1] format
    # Episodic_life: considers losing a life as the end of an episode but only reset on true game over. 
        # Used by deepmind as it helps value estimation

env.seed(42)

print("Environment input shape: {}".format(env.observation_space.shape))
print("Environment output shape: {}".format(env.action_space.n))

num_possible_actions = env.action_space.n


Environment input shape: (84, 84, 4)
Environment output shape: 6


In [8]:
#Baseline testing
history = []
tot_ep_reward = 0

for episode in range(1,1000): 
    env.reset()
    tot_ep_reward = 0
    
    for timestep in range(1, 25000):
    
        # env.render(); 
        
        action = np.random.choice(num_possible_actions)

        next_state, reward, done, _ = env.step(action)
        next_state = np.asarray(next_state)

        tot_ep_reward += reward

        cur_state = next_state
    
        if done == True:
            history.append(tot_ep_reward)
            break
    
#calculate reward of last 100 episodes
mean_reward = np.mean(history)

print("mean reward is: {}".format(mean_reward))

mean reward is: 2.696969696969697


In [10]:
print(history)

[6.0, 0.0, 0.0, 3.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 1.0, 3.0, 11.0, 3.0, 1.0, 2.0, 13.0, 1.0, 4.0, 1.0, 3.0, 1.0, 6.0, 0.0, 3.0, 4.0, 1.0, 0.0, 2.0, 0.0, 1.0, 2.0, 3.0, 2.0, 0.0, 2.0, 0.0, 2.0, 4.0, 2.0, 2.0, 4.0, 4.0, 2.0, 2.0, 4.0, 0.0, 2.0, 3.0, 0.0, 6.0, 3.0, 3.0, 1.0, 4.0, 1.0, 3.0, 4.0, 1.0, 3.0, 1.0, 5.0, 3.0, 6.0, 3.0, 1.0, 7.0, 4.0, 4.0, 1.0, 3.0, 1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 4.0, 5.0, 4.0, 3.0, 3.0, 2.0, 2.0, 8.0, 3.0, 1.0, 7.0, 1.0, 0.0, 5.0, 2.0, 3.0, 5.0, 1.0, 1.0, 1.0, 4.0, 2.0]
