In [70]:
import tensorflow as tf
import numpy as np
import gym
from gym.spaces import Box, Discrete



In [3]:
class ReplayBuffer(object):
    """
    A simple FIFO experience replay buffer for all agents.

    Transitions live in preallocated float32 arrays that are written
    circularly: once ``max_size`` transitions have been stored, each new
    transition overwrites the oldest slot.
    """

    def __init__(self, obs_dim, act_dim, size):
        """
        Allocate empty storage.

        Args:
            obs_dim: number of columns per stored observation.
            act_dim: number of columns per stored action.
            size: capacity, i.e. the maximum number of transitions retained.
        """
        self.obs1_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.obs2_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.acts_buf = np.zeros((size, act_dim), dtype=np.float32)
        self.rews_buf = np.zeros((size,), dtype=np.float32)
        self.done_buf = np.zeros((size,), dtype=np.float32)
        self.ptr = 0          # slot the next transition is written to
        self.size = 0         # number of valid transitions currently held
        self.max_size = size  # capacity; ptr wraps around at this value

    def store(self, obs, act, rew, next_obs, done):
        """Write one transition at the cursor, then advance the cursor (wrapping)."""
        slot = self.ptr
        self.obs1_buf[slot] = obs
        self.obs2_buf[slot] = next_obs
        self.acts_buf[slot] = act
        self.rews_buf[slot] = rew
        self.done_buf[slot] = done
        self.ptr = (slot + 1) % self.max_size
        # The fill count grows until the buffer is full, then stays at capacity.
        if self.size < self.max_size:
            self.size += 1

    def sample_batch(self, batch_size=32):
        """
        Draw ``batch_size`` stored transitions uniformly at random, with replacement.

        Returns:
            dict with keys 'obs1', 'obs2', 'acts', 'rews', 'done'. Row i of every
            array comes from the same stored transition.
        """
        picks = np.random.randint(0, self.size, size=batch_size)
        columns = (
            ('obs1', self.obs1_buf),
            ('obs2', self.obs2_buf),
            ('acts', self.acts_buf),
            ('rews', self.rews_buf),
            ('done', self.done_buf),
        )
        return {key: buf[picks] for key, buf in columns}

In [79]:
class DQN(object):
    """
    Skeleton DQN agent.

    Resolves hyperparameters, creates the gym environment, and allocates a
    replay buffer sized from the environment's observation/action spaces.
    """

    def __init__(self, **kwargs):
        """
        Build the agent.

        Args:
            **kwargs: hyperparameter overrides merged over the defaults
                (``env='CartPole-v0'``, ``replay_size=100``). Extra keys such
                as ``lr`` are kept verbatim in ``self.hyperparameters``.
        """
        self.init_hyperparameters(**kwargs)
        # NOTE(review): for a Discrete action space, act_dim is the number of
        # actions, so acts_buf gets one column per action. If a scalar action
        # index is passed to ReplayBuffer.store, numpy broadcasts it across all
        # of those columns. DQN usually stores the index in a single column;
        # confirm intent before adding training code.
        self.replay_buffer = ReplayBuffer(
            obs_dim=self.obs_dim,
            act_dim=self.act_dim,
            size=self.hyperparameters['replay_size'],
        )

    def init_hyperparameters(self, **kwargs):
        """
        Merge overrides into the default hyperparameters and build the env.

        Sets ``self.hyperparameters``, ``self.env``, ``self.obs_dim`` (first
        dimension of the observation space shape) and ``self.act_dim`` (``n``
        for a Discrete action space, otherwise the first dimension of its shape).

        Args:
            **kwargs: hyperparameter overrides; each key replaces or adds an entry.
        """
        self.hyperparameters = dict(env='CartPole-v0', replay_size=100)
        # Caller-supplied values take precedence over the defaults above.
        self.hyperparameters.update(kwargs)
        self.env = gym.make(self.hyperparameters['env'])
        self.obs_dim = self.env.observation_space.shape[0]
        if isinstance(self.env.action_space, Discrete):
            self.act_dim = self.env.action_space.n
        else:
            self.act_dim = self.env.action_space.shape[0]
            

In [80]:


# Discrete-action agent on CartPole. `lr` is only recorded in
# agent.hyperparameters; nothing shown here reads it yet.
agent = DQN(env='CartPole-v0', lr=0.01)

In [81]:
# Second agent, presumably exercising the non-Discrete (action_space.shape[0])
# branch of DQN.init_hyperparameters.
# NOTE(review): HalfCheetah-v2 is a MuJoCo env, so gym.make needs mujoco-py
# installed. Also, a DQN on a continuous action space is unusual; confirm intent.
agent2 = DQN(
    env='HalfCheetah-v2',
    lr=0.01,
)