<a href="https://colab.research.google.com/github/kuds/rl-atari-breakout/blob/main/%5BAtari%20Breakout%5D%20Model-Based%20Reinforcement%20Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Using Model-Based Reinforcement Learning to play Atari's Breakout

## References/Repositories
- [Model-Based Reinforcement Learning for Atari - Simplified Repo](https://github.com/dhruvramani/model-based-atari/tree/master)
- [Model-Based Reinforcement Learning for Atari - Paper](https://arxiv.org/abs/1903.00374)
- [Tensor2Tensor - RL Code](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/rl)

In [1]:
!pip install gymnasium

Collecting gymnasium
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m958.1/958.1 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium
Successfully installed farama-notifications-0.0.4 gymnasium-1.0.0


In [2]:
import gymnasium
import platform
import torch
import numpy
from importlib.metadata import version
from datetime import datetime
import google.colab.drive

# Load the CartPole-v1 environment
env = gymnasium.make("CartPole-v1")

In [3]:
print(f"Python Version: {platform.python_version()}")
print(f"Torch Version: {version('torch')}")
print(f"Is Cuda Available: {torch.cuda.is_available()}")
print(f"Cuda Version: {torch.version.cuda}")
print(f"Gymnasium Version: {version('gymnasium')}")
print(f"Numpy Version: {version('numpy')}")

Python Version: 3.10.12
Torch Version: 2.5.0+cu121
Is Cuda Available: False
Cuda Version: 12.1
Gymnasium Version: 1.0.0
Numpy Version: 1.26.4


In [4]:
def collect_data(env, num_episodes=1000):
    data = []
    for _ in range(num_episodes):
        state = env.reset()
        terminated = False
        truncated = False
        while not (terminated or truncated):
            action = env.action_space.sample()
            next_state, reward, terminated, truncated, info = env.step(action)
            data.append((state, action, reward, next_state))
            if(terminated or truncated):
              print(next_state)
            state = next_state
    return data

# Collect data

data = collect_data(env)
print(len(data))

[ 0.09985067  1.1573845  -0.22733371 -2.1230328 ]
[-0.14263262 -0.937871    0.22696428  1.7066796 ]
[-0.07151423 -0.9587023   0.22259556  1.9680437 ]
[-0.18474822 -0.81536525  0.23248145  1.5020494 ]
[ 0.16199492  0.38389528 -0.22363782 -0.8589244 ]
[-0.06390031 -0.03285903  0.21369837  0.53873193]
[-0.20088229 -0.60771954  0.22065611  1.156212  ]
[-0.22312042 -1.7290013   0.2535453   2.663306  ]
[-0.0717577  -1.165038    0.23382524  2.0948775 ]
[ 0.09235707  0.5676411  -0.22014874 -1.4649719 ]
[ 0.1689128   0.6503425  -0.23379983 -1.3019111 ]
[-0.10049108 -0.5654767   0.22380552  1.2447968 ]
[-0.03197342 -0.45850286  0.21602406  1.3897356 ]
[ 0.19051856  1.5527407  -0.22743213 -2.4447765 ]
[-0.20329696 -0.5940527   0.22163054  1.1417512 ]
[-0.0968428  -0.8284884   0.21895196  1.567352  ]
[-0.1085595  -0.41229263  0.21909222  0.95741194]
[-0.18645835 -1.5370182   0.24279214  2.331681  ]
[-0.00744801 -0.24765879  0.21804842  1.0654683 ]
[-0.20567559 -0.6362208   0.22102894  1.0211744 ]


In [5]:
print(data[0][0])

(array([-0.00487464, -0.02781332, -0.02078416, -0.00240914], dtype=float32), {})


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define the neural network for the dynamics model
class DynamicsModel(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(DynamicsModel, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, state_dim)  # Predict next state
        self.reward = nn.Linear(128, 1)       # Predict reward

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        next_state = self.fc3(x)
        reward = self.reward(x)
        return next_state, reward


In [None]:
# Initialize model and optimizer
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
model = DynamicsModel(state_dim, action_dim)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

# Convert collected data to tensors for training
states = torch.tensor([d[0] for d in data], dtype=torch.float32)
actions = torch.tensor([d[1] for d in data], dtype=torch.float32).unsqueeze(1)
next_states = torch.tensor([d[3] for d in data], dtype=torch.float32)
rewards = torch.tensor([d[2] for d in data], dtype=torch.float32).unsqueeze(1)

# Train the dynamics model
for epoch in range(100):
    optimizer.zero_grad()
    predicted_next_states, predicted_rewards = model(states, actions)
    loss = criterion(predicted_next_states, next_states) + criterion(predicted_rewards, rewards)
    loss.backward()
    optimizer.step()

print("Model training complete!")


ValueError: expected sequence of length 4 at dim 2 (got 0)

In [None]:
[d[0] for d in data]

[(array([ 0.00603946,  0.03967333,  0.00213784, -0.01986235], dtype=float32),
  {}),
 array([ 0.00683293, -0.15547922,  0.00174059,  0.27349433], dtype=float32),
 array([ 0.00372334, -0.35062596,  0.00721048,  0.56672573], dtype=float32),
 array([-0.00328917, -0.5458483 ,  0.018545  ,  0.8616715 ], dtype=float32),
 array([-0.01420614, -0.35098374,  0.03577843,  0.57487684], dtype=float32),
 array([-0.02122582, -0.15638115,  0.04727596,  0.29367638], dtype=float32),
 array([-0.02435344, -0.35214412,  0.05314949,  0.6008867 ], dtype=float32),
 array([-0.03139632, -0.54796773,  0.06516723,  0.9098259 ], dtype=float32),
 array([-0.04235568, -0.3537854 ,  0.08336374,  0.6383163 ], dtype=float32),
 array([-0.04943138, -0.5499647 ,  0.09613007,  0.95604396], dtype=float32),
 array([-0.06043068, -0.74623895,  0.11525095,  1.2773148 ], dtype=float32),
 array([-0.07535546, -0.9426261 ,  0.14079724,  1.60375   ], dtype=float32),
 array([-0.09420798, -1.1391054 ,  0.17287225,  1.9368104 ], dtype=f

In [None]:
def mpc_action_selection(model, current_state, num_simulations=100, horizon=10):
    best_action = None
    best_reward = -np.inf

    for _ in range(num_simulations):
        simulated_state = current_state
        total_reward = 0
        for _ in range(horizon):
            action = np.random.choice([0, 1])  # Random action sampling for now
            action_tensor = torch.tensor([action], dtype=torch.float32).unsqueeze(0)
            state_tensor = torch.tensor(simulated_state, dtype=torch.float32).unsqueeze(0)
            next_state, reward = model(state_tensor, action_tensor)
            total_reward += reward.item()
            simulated_state = next_state.detach().numpy()[0]

        if total_reward > best_reward:
            best_reward = total_reward
            best_action = action

    return best_action


In [None]:
def evaluate_model_based_agent(env, model, num_episodes=10):
    for episode in range(num_episodes):
        state = env.reset()
        done = False
        total_reward = 0
        while not done:
            action = mpc_action_selection(model, state)
            state, reward, done, _ = env.step(action)
            total_reward += reward
        print(f"Episode {episode + 1}: Total Reward: {total_reward}")

# Evaluate the agent
evaluate_model_based_agent(env, model)


In [None]:
import os
import gym
import numpy as np
from tqdm import tqdm
import tensorflow as tf

from utils import *

g_env_model = None
def cached_world_model(sess, ob_shape, action_dim, config, path):
    global g_env_model
    if g_env_model is None:
        old_val = config.n_envs
        config.n_envs = 1
        g_env_model = EnvModel(ob_shape, action_dim, config)
        save_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='env_model')
        loader = tf.train.Saver(var_list=save_vars)
        loader.restore(sess, path)
        printstar('World Model Restored')
        config.n_envs = old_val

    return g_env_model

def inject_additional_input(layer, inputs, name, mode="multi_additive"):
  """Injects the additional input into the layer.

  Args:
    layer: layer that the input should be injected to.
    inputs: inputs to be injected.
    name: TF scope name.
    mode: how the infor should be added to the layer:
      "concat" concats as additional channels.
      "multiplicative" broadcasts inputs and multiply them to the channels.
      "multi_additive" broadcasts inputs and multiply and add to the channels.

  Returns:
    updated layer.

  Raises:
    ValueError: in case of unknown mode.
  """
  layer_shape = shape_list(layer)
  input_shape = shape_list(inputs)
  zeros_mask = tf.zeros(layer_shape, dtype=tf.float32)
  if mode == "concat":
    emb = common_video.encode_to_shape(inputs, layer_shape, name)
    layer = tf.concat(values=[layer, emb], axis=-1)
  elif mode == "multiplicative":
    filters = layer_shape[-1]
    input_reshaped = tf.reshape(inputs, [-1, 1, 1, input_shape[-1]])
    input_mask = tf.layers.dense(input_reshaped, filters, name=name)
    input_broad = input_mask + zeros_mask
    layer *= input_broad
  elif mode == "multi_additive":
    filters = layer_shape[-1]
    input_reshaped = tf.reshape(inputs, [-1, 1, 1, input_shape[-1]])
    input_mul = tf.layers.dense(input_reshaped, filters, name=name + "_mul")
    layer *= tf.nn.sigmoid(input_mul)
    input_add = tf.layers.dense(input_reshaped, filters, name=name + "_add")
    layer += input_add
  else:
    raise ValueError("Unknown injection mode: %s" % mode)

  return layer

class EnvModel(object):
    def __init__(self, obs_shape, action_dim, config):

        self.obs_shape = obs_shape
        self.action_dim = action_dim
        self.config = config

        self.hidden_size = config.hidden_size
        self.layers = config.n_layers
        self.dropout_p = config.dropout_p

        if(config.activation_fn == 'relu'):
            self.activation_fn = tf.nn.relu
        elif (config.activation_fn == 'tanh'):
            self.activation_fn = tf.nn.tanh

        self.l2_clip = config.l2_clip
        self.softmax_clip = config.softmax_clip
        self.reward_coeff = config.reward_coeff
        self.n_envs = config.n_envs
        self.max_ep_len = config.max_ep_len
        self.log_interval = config.log_interval

        self.is_policy = config.is_policy
        self.has_rewards = config.has_rewards
        self.num_rewards = config.num_rewards

        self.width, self.height, self.depth = self.obs_shape

        self.states_ph = tf.placeholder(tf.float32, [None, self.width, self.height, self.depth])
        self.actions_ph = tf.placeholder(tf.uint8, [None, 1])
        self.actions_oph = tf.one_hot(self.actions_ph, depth=action_dim)
        self.target_states = tf.placeholder(tf.float32, [None, self.width, self.height, self.depth])
        if(self.has_rewards):
            self.target_rewards = tf.placeholder(tf.uint8, [None, self.num_rewards])

        # NOTE - Implement policy and value parts later
        with tf.variable_scope("env_model"):
            self.state_pred, self.reward_pred, _, _ = self.network()

        # NOTE - Change this maybe to video_l2_loss
        self.state_loss = tf.math.maximum(tf.reduce_sum(tf.pow(self.state_pred - self.target_states, 2)), self.l2_clip)
        self.loss = self.state_loss

        if(self.has_rewards):
            self.reward_loss = tf.math.maximum(tf.reduce_mean(tf.losses.softmax_cross_entropy(self.tw_one_hot, self.reward_pred)), self.softmax_clip)
            self.loss = self.loss + (self.reward_coeff * self.reward_loss)

        self.opt = tf.train.AdamOptimizer().minimize(self.loss)

        tf.summary.scalar('loss', self.loss)
        if(self.has_rewards):
            tf.summary.scalar('image_loss', self.state_loss)
            tf.summary.scalar('reward_loss', self.reward_loss)

    def generate_data(self, envs):
        states = envs.reset()
        for frame_idx in range(self.max_ep_len):
            states = states.reshape(self.n_envs, self.width, self.height, self.depth)
            if(self.n_envs == 1):
                actions = envs.action_space.sample()
            else:
                actions = [envs.action_space.sample() for _ in range(self.n_envs)]
            next_states, rewards, dones, _ = envs.step(actions)
            next_states = next_states.reshape(self.n_envs, self.width, self.height, self.depth)

            yield frame_idx, states, actions, rewards, next_states, dones
            states = next_states
            if(self.n_envs == 1 and dones == True):
                states = envs.reset()

    def network(self):
        def middle_network(layer):
            x = layer
            kernel1 = (3, 3)
            filters = shape_list(x)[-1]
            for i in range(2):
              with tf.variable_scope("layer%d" % i):
                y = tf.nn.dropout(x, 1.0 - 0.5)
                y = tf.layers.conv2d(y, filters, kernel1, activation=self.activation_fn,
                                     strides=(1, 1), padding="SAME")
                if i == 0:
                  x = y
                else:
                  x = layer_norm(x + y)
            return x

        batch_size = tf.shape(self.states_ph)[0]

        filters = self.hidden_size
        kernel2 = (4, 4)
        action = self.actions_oph

        # Normalize states
        if(self.n_envs > 1):
            states = [standardize_images(self.states_ph[i, :, :, :]) for i in range(self.n_envs)]
            stacked_states = tf.stack(states)
        else:
            stacked_states = standardize_images(self.states_ph)
        inputs_shape = shape_list(stacked_states)

        # Using non-zero bias initializer below for edge cases of uniform inputs.
        x = tf.layers.dense(
            stacked_states, filters, name="inputs_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = add_timing_signal_nd(x)

        # Down-stride.
        layer_inputs = [x]
        for i in range(self.layers):
          with tf.variable_scope("downstride%d" % i):
            layer_inputs.append(x)
            x = tf.nn.dropout(x, 1.0 - self.dropout_p)
            x = make_even_size(x)
            if i < 2:
              filters *= 2
            x = add_timing_signal_nd(x)
            x = tf.layers.conv2d(x, filters, kernel2, activation=self.activation_fn,
                                 strides=(2, 2), padding="SAME")
            x = layer_norm(x)

        if self.is_policy:
          with tf.variable_scope("policy"):
            x_flat = tf.layers.flatten(x)
            policy_pred = tf.layers.dense(x_flat, self.action_dim)
            value_pred = tf.layers.dense(x_flat, 1)
            value_pred = tf.squeeze(value_pred, axis=-1)
        else:
          policy_pred, value_pred = None, None

        x = inject_additional_input(x, action, "action_enc", "multi_additive")

        # Inject latent if present. Only for stochastic models.
        target_states = standardize_images(self.target_states)

        x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        x = middle_network(x)

        # Up-convolve.
        layer_inputs = list(reversed(layer_inputs))
        for i in range(self.layers):
          with tf.variable_scope("upstride%d" % i):
            x = tf.nn.dropout(x, 1.0 - 0.1)
            if i >= self.layers - 2:
              filters //= 2
            x = tf.layers.conv2d_transpose(
                x, filters, kernel2, activation=self.activation_fn,
                strides=(2, 2), padding="SAME")
            y = layer_inputs[i]
            shape = shape_list(y)
            x = x[:, :shape[1], :shape[2], :]
            x = layer_norm(x + y)
            x = add_timing_signal_nd(x)

        # Cut down to original size.
        x = x[:, :inputs_shape[1], :inputs_shape[2], :]
        x_fin = tf.reduce_mean(x, axis=[1, 2], keepdims=True)

        x = tf.layers.dense(x, self.depth, name="logits")

        reward_pred = None
        if self.has_rewards:
          # Reward prediction based on middle and final logits.
          reward_pred = tf.concat([x_mid, x_fin], axis=-1)
          reward_pred = tf.nn.relu(tf.layers.dense(
              reward_pred, 128, name="reward_pred"))
          reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
          reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims

        return x, reward_pred, policy_pred, value_pred

    def imagine(self, sess, obs, action):
        action = np.array(action)
        action = np.reshape(action, (1, 1))
        obs = obs.reshape(1, self.width, self.height, self.depth)
        next_pred_ob = sess.run(self.state_pred, feed_dict={self.states_ph : obs, self.actions_ph : action})
        next_pred_ob = next_pred_ob.reshape(self.width, self.height, self.depth)
        next_pred_ob = np.rint(next_pred_ob)
        return next_pred_ob

    def train(self, world_model_path):
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            losses = []
            all_rewards = []
            save_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='env_model')
            saver = tf.train.Saver(var_list=save_vars)

            train_writer = tf.summary.FileWriter('./env_logs/train/', graph=sess.graph)
            summary_op = tf.summary.merge_all()

            if(self.n_envs == 1):
                envs = make_env()()
            else:
                envs = [make_env() for i in range(self.n_envs)]
                envs = SubprocVecEnv(envs)

            for idx, states, actions, rewards, next_states, dones in tqdm(
                self.generate_data(envs), total=self.max_ep_len):
                actions = np.array(actions)
                actions = np.reshape(actions, (-1, 1))

                if(self.has_rewards):
                    target_reward = reward_to_target(rewards)
                    loss, reward_loss, state_loss, summary, _ = sess.run([self.loss, self.reward_loss, self.state_loss,
                        summary_op, self.opt], feed_dict={
                        self.states_ph: states,
                        self.actions_ph: actions,
                        self.target_states: next_states,
                        self.target_rewards: target_reward
                    })
                else :
                    loss, summary, _ = sess.run([self.loss, summary_op, self.opt], feed_dict={
                        self.states_ph: states,
                        self.actions_ph: actions,
                        self.target_states: next_states,
                    })

                if idx % self.log_interval == 0:
                    if(self.has_rewards):
                        print('%i => Loss : %.4f, Reward Loss : %.4f, Image Loss : %.4f' % (idx, loss, reward_loss, state_loss))
                    else :
                        print('%i => Loss : %.4f' % (idx, loss))
                    saver.save(sess, '{}/env_model.ckpt'.format(world_model_path))
                    print('Environment model saved')

                train_writer.add_summary(summary, idx)
            envs.close()

In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from world_model import *
from config import argparser
from utils import make_env as env_fn, printstar

g_env_model = None

def main(config):
    global env_fn
    env = env_fn()()
    env.reset()

    action_dim = 5
    ob_shape = env.observation_space.shape
    world_model_path = os.path.expanduser(os.path.join(config.model_dir, config.world_model_type + "_" + config.world_model_path))

    if(config.train_world_model):
        env_model = EnvModel(ob_shape, action_dim, config)
        if(not os.path.exists(world_model_path)):
            os.mkdir(world_model_path)
        printstar("Training World Model")
        env_model.train(world_model_path)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        if(config.eval_world_model):
            env_model = cached_world_model(sess, ob_shape, action_dim, config, world_model_path + '/env_model.ckpt')
            evaluate_world_model(env, sess, env_model, config)

def evaluate_world_model(env, sess, env_model, config, policy=None):
    printstar("Testing World Model")
    obs = env.reset()
    for t in range(config.max_eval_iters):
        if(policy is None):
                action = env.action_space.sample()
        else:
            action = policy(obs)

        next_pred_ob = env_model.imagine(sess, obs, action)
        imgplot = plt.imshow(next_pred_ob)
        plt.savefig('./figs/world_model_eval.png')

        env.render()
        obs, reward, dones, info = env.step(action)
        inp = input("Press 0 to exit : ")
        if(inp == "0"):
            break

if __name__ == '__main__':
    config = argparser()
    mp.set_start_method('spawn', force=True)
    main(config)

In [None]:
import os
import argparse

def str2bool(v):
    return v.lower() == 'true'

def str2list(v):
    if not v:
        return v
    else:
        return [v_ for v_ in v.split(',')]

def argparser():
    parser = argparse.ArgumentParser("Model-Based Reinforcement Learning for Atari",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--train_world_model', type=str2bool, default=True)
    parser.add_argument('--eval_world_model', type=str2bool, default=True)

    parser.add_argument('--num_rewards', type=int, default=1)
    parser.add_argument('--n_envs', type=int, default=16)
    parser.add_argument('--is_policy', type=str2bool, default=False)
    parser.add_argument('--has_rewards', type=str2bool, default=False)
    parser.add_argument('--hidden_size', type=int, default=64)
    parser.add_argument('--n_layers', type=int, default=6)
    parser.add_argument('--dropout_p', type=float, default=0.1)
    parser.add_argument('--max_ep_len', type=int, default=500000)
    parser.add_argument('--l2_clip', type=float, default=0.0)
    parser.add_argument('--softmax_clip', type=float, default=0.03)
    parser.add_argument('--reward_coeff', type=float, default=0.1)
    parser.add_argument('--log_interval', type=int, default=100)
    parser.add_argument('--activation_fn', type=str, default='relu', choices=['relu', 'tanh'])

    parser.add_argument('--log_dir', type=str, default='./log')
    parser.add_argument('--model_dir', type=str, default='./models')
    parser.add_argument('--world_model_path', type=str, default="CarRacingWorldModel")

    parser.add_argument('--total_timesteps', type=int, default=int(1e6))
    parser.add_argument('--max_eval_iters', type=int, default=int(1e3))

    parser.add_argument('--render', type=str2bool, default=True, help='Render frames')
    parser.add_argument('--debug', type=str2bool, default=False, help='See debugging info')

    args = parser.parse_args()
    return args

In [None]:
# Code is from OpenAI Baseline and Tensor2Tensor

import itertools
import numpy as np
from gym.envs.box2d import CarRacing
import multiprocessing as mp

def printstar(string, num_stars=50):
    print("*" * num_stars)
    print(string)
    print("*" * num_stars)

def make_env():
    def _thunk():
        env = CarRacing(grayscale=0, show_info_panel=0, discretize_actions="hard", frames_per_state=1, num_lanes=1, num_tracks=1)
        return env
    return _thunk

def worker(remote, parent_remote, env_fn_wrapper):
    parent_remote.close()
    env = env_fn_wrapper.x()
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            ob, reward, done, info = env.step(data)
            if done:
                ob = env.reset()
            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            ob = env.reset()
            remote.send(ob)
        elif cmd == 'reset_task':
            ob = env.reset_task()
            remote.send(ob)
        elif cmd == 'close':
            remote.close()
            break
        elif cmd == 'get_spaces':
            remote.send((env.observation_space, env.action_space))
        else:
            raise NotImplementedError

class VecEnv(object):
    """
    An abstract asynchronous, vectorized environment.
    """
    def __init__(self, num_envs, observation_space, action_space):
        self.num_envs = num_envs
        self.observation_space = observation_space
        self.action_space = action_space

    def reset(self):
        """
        Reset all the environments and return an array of
        observations, or a tuple of observation arrays.
        If step_async is still doing work, that work will
        be cancelled and step_wait() should not be called
        until step_async() is invoked again.
        """
        pass

    def step_async(self, actions):
        """
        Tell all the environments to start taking a step
        with the given actions.
        Call step_wait() to get the results of the step.
        You should not call this if a step_async run is
        already pending.
        """
        pass

    def step_wait(self):
        """
        Wait for the step taken with step_async().
        Returns (obs, rews, dones, infos):
         - obs: an array of observations, or a tuple of
                arrays of observations.
         - rews: an array of rewards
         - dones: an array of "episode done" booleans
         - infos: a sequence of info objects
        """
        pass

    def close(self):
        """
        Clean up the environments' resources.
        """
        pass

    def step(self, actions):
        self.step_async(actions)
        return self.step_wait()


class CloudpickleWrapper(object):
    """
    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
    """
    def __init__(self, x):
        self.x = x
    def __getstate__(self):
        import cloudpickle
        return cloudpickle.dumps(self.x)
    def __setstate__(self, ob):
        import pickle
        self.x = pickle.loads(ob)

class SubprocVecEnv(VecEnv):
    def __init__(self, env_fns, spaces=None):
        """
        envs: list of gym environments to run in subprocesses
        """
        self.waiting = False
        self.closed = False
        nenvs = len(env_fns)
        self.nenvs = nenvs
        self.remotes, self.work_remotes = zip(*[mp.Pipe() for _ in range(nenvs)])
        self.ps = [mp.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
            for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.daemon = True # if the main process crashes, we should not cause things to hang
            p.start()
        for remote in self.work_remotes:
            remote.close()

        self.remotes[0].send(('get_spaces', None))
        observation_space, action_space = self.remotes[0].recv()
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)

    def step_async(self, actions):
        if(type(actions) == int):
            for remote in self.remotes:
                remote.send(('step', actions))
        else:
            for remote, action in zip(self.remotes, actions):
                remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        for remote in self.remotes:
            remote.send(('reset', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def reset_task(self):
        for remote in self.remotes:
            remote.send(('reset_task', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def close(self):
        if self.closed:
            return
        if self.waiting:
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()
            self.closed = True

    def __len__(self):
        return self.nenvs



def shape_list(x):
  """Return list of dims, statically where possible."""
  x = tf.convert_to_tensor(x)

  # If unknown rank, return dynamic shape
  if x.get_shape().dims is None:
    return tf.shape(x)

  static = x.get_shape().as_list()
  shape = tf.shape(x)

  ret = []
  for i, dim in enumerate(static):
    if dim is None:
      dim = shape[i]
    ret.append(dim)
  return ret

def to_float(x):
  """Cast x to float; created because tf.to_float is deprecated."""
  return tf.cast(x, tf.float32)

def cast_like(x, y):
  """Cast x to y's dtype, if necessary."""
  x = tf.convert_to_tensor(x)
  y = tf.convert_to_tensor(y)

  if x.dtype.base_dtype == y.dtype.base_dtype:
    return x

  cast_x = tf.cast(x, y.dtype)
  if cast_x.device != x.device:
    x_name = "(eager Tensor)"
    try:
      x_name = x.name
    except AttributeError:
      pass
    tf.logging.warning("Cast for %s may induce copy from '%s' to '%s'", x_name,
                       x.device, cast_x.device)
  return cast_x

def layer_norm_vars(filters):
  """Create Variables for layer norm."""
  scale = tf.get_variable(
      "layer_norm_scale", [filters], initializer=tf.ones_initializer())
  bias = tf.get_variable(
      "layer_norm_bias", [filters], initializer=tf.zeros_initializer())
  return scale, bias


def layer_norm_compute(x, epsilon, scale, bias, layer_collection=None):
  """Layer norm raw computation."""

  # Save these before they get converted to tensors by the casting below
  params = (scale, bias)

  epsilon, scale, bias = [cast_like(t, x) for t in [epsilon, scale, bias]]
  mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
  variance = tf.reduce_mean(
      tf.squared_difference(x, mean), axis=[-1], keepdims=True)
  norm_x = (x - mean) * tf.rsqrt(variance + epsilon)

  output = norm_x * scale + bias


  return output

def layer_norm(x,
               filters=None,
               epsilon=1e-6,
               name=None,
               reuse=None,
               layer_collection=None):
  """Layer normalize the tensor x, averaging over the last dimension."""
  if filters is None:
    filters = shape_list(x)[-1]
  with tf.variable_scope(
      name, default_name="layer_norm", values=[x], reuse=reuse):
    scale, bias = layer_norm_vars(filters)
    return layer_norm_compute(x, epsilon, scale, bias,
                              layer_collection=layer_collection)

def standardize_images(x):
  """Image standardization on batches and videos."""
  with tf.name_scope("standardize_images", values=[x]):
    x_shape = shape_list(x)
    x = to_float(tf.reshape(x, [-1] + x_shape[-3:]))
    x_mean = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
    x_variance = tf.reduce_mean(
        tf.squared_difference(x, x_mean), axis=[1, 2], keepdims=True)
    num_pixels = to_float(x_shape[-2] * x_shape[-3])
    x = (x - x_mean) / tf.maximum(tf.sqrt(x_variance), tf.rsqrt(num_pixels))
    return tf.reshape(x, x_shape)

def pad_to_same_length(x, y, final_length_divisible_by=1, axis=1):
  """Pad tensors x and y on axis 1 so that they have the same length."""
  if axis not in [1, 2]:
    raise ValueError("Only axis=1 and axis=2 supported for now.")
  with tf.name_scope("pad_to_same_length", values=[x, y]):
    x_length = shape_list(x)[axis]
    y_length = shape_list(y)[axis]
    if (isinstance(x_length, int) and isinstance(y_length, int) and
        x_length == y_length and final_length_divisible_by == 1):
      return x, y
    max_length = tf.maximum(x_length, y_length)
    if final_length_divisible_by > 1:
      # Find the nearest larger-or-equal integer divisible by given number.
      max_length += final_length_divisible_by - 1
      max_length //= final_length_divisible_by
      max_length *= final_length_divisible_by
    length_diff1 = max_length - x_length
    length_diff2 = max_length - y_length

    def padding_list(length_diff, arg):
      if axis == 1:
        return [[[0, 0], [0, length_diff]],
                tf.zeros([tf.rank(arg) - 2, 2], dtype=tf.int32)]
      return [[[0, 0], [0, 0], [0, length_diff]],
              tf.zeros([tf.rank(arg) - 3, 2], dtype=tf.int32)]

    paddings1 = tf.concat(padding_list(length_diff1, x), axis=0)
    paddings2 = tf.concat(padding_list(length_diff2, y), axis=0)
    res_x = tf.pad(x, paddings1)
    res_y = tf.pad(y, paddings2)
    # Static shapes are the same except for axis=1.
    x_shape = x.shape.as_list()
    x_shape[axis] = None
    res_x.set_shape(x_shape)
    y_shape = y.shape.as_list()
    y_shape[axis] = None
    res_y.set_shape(y_shape)
    return res_x, res_y

def make_even_size(x):
  """Pad x to be even-sized on axis 1 and 2, but only if necessary."""
  x_shape = x.get_shape().as_list()
  assert len(x_shape) > 2, "Only 3+-dimensional tensors supported."
  shape = [dim if dim is not None else -1 for dim in x_shape]
  new_shape = x_shape  # To make sure constant shapes remain constant.
  if x_shape[1] is not None:
    new_shape[1] = 2 * int(math.ceil(x_shape[1] * 0.5))
  if x_shape[2] is not None:
    new_shape[2] = 2 * int(math.ceil(x_shape[2] * 0.5))
  if shape[1] % 2 == 0 and shape[2] % 2 == 0:
    return x
  if shape[1] % 2 == 0:
    x, _ = pad_to_same_length(x, x, final_length_divisible_by=2, axis=2)
    x.set_shape(new_shape)
    return x
  if shape[2] % 2 == 0:
    x, _ = pad_to_same_length(x, x, final_length_divisible_by=2, axis=1)
    x.set_shape(new_shape)
    return x
  x, _ = pad_to_same_length(x, x, final_length_divisible_by=2, axis=1)
  x, _ = pad_to_same_length(x, x, final_length_divisible_by=2, axis=2)
  x.set_shape(new_shape)
  return x


def add_timing_signal_nd(x, min_timescale=1.0, max_timescale=1.0e4):
  """Adds a bunch of sinusoids of different frequencies to a Tensor.

  Each channel of the input Tensor is incremented by a sinusoid of a different
  frequency and phase in one of the positional dimensions.

  This allows attention to learn to use absolute and relative positions.
  Timing signals should be added to some precursors of both the query and the
  memory inputs to attention.

  The use of relative position is possible because sin(a+b) and cos(a+b) can be
  expressed in terms of b, sin(a) and cos(a).

  x is a Tensor with n "positional" dimensions, e.g. one dimension for a
  sequence or two dimensions for an image

  We use a geometric sequence of timescales starting with
  min_timescale and ending with max_timescale.  The number of different
  timescales is equal to channels // (n * 2). For each timescale, we
  generate the two sinusoidal signals sin(timestep/timescale) and
  cos(timestep/timescale).  All of these sinusoids are concatenated in
  the channels dimension.

  Args:
    x: a Tensor with shape [batch, d1 ... dn, channels]
    min_timescale: a float
    max_timescale: a float

  Returns:
    a Tensor the same shape as x.
  """
  num_dims = len(x.get_shape().as_list()) - 2
  channels = shape_list(x)[-1]
  num_timescales = channels // (num_dims * 2)
  log_timescale_increment = (
      math.log(float(max_timescale) / float(min_timescale)) /
      (tf.to_float(num_timescales) - 1))
  inv_timescales = min_timescale * tf.exp(
      tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
  for dim in range(num_dims):
    length = shape_list(x)[dim + 1]
    position = tf.to_float(tf.range(length))
    scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(
        inv_timescales, 0)
    signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
    prepad = dim * 2 * num_timescales
    postpad = channels - (dim + 1) * 2 * num_timescales
    signal = tf.pad(signal, [[0, 0], [prepad, postpad]])
    for _ in range(1 + dim):
      signal = tf.expand_dims(signal, 0)
    for _ in range(num_dims - 1 - dim):
      signal = tf.expand_dims(signal, -2)
    x += signal
  return x
