In [11]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

import os
import sys

# The project modules (Algorithm, Environment, Replay_Buffer) are imported from the
# directory one level above the notebook's working directory. The previous version
# computed that directory three times via os.path.abspath('<Module>.py'). abspath only
# joins the name onto the cwd, so all three calls appended the same entry, and every
# re-run of this cell appended three more. Compute it once and add it only if missing.
PROJECT_ROOT = os.path.dirname(os.getcwd())
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
import Algorithm as ALGS
import Environment as ENVS
import Replay_Buffer


In [12]:
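# NOTE: commented-out inline version of the replay buffer, not used by this notebook.
# The buffer actually used is Replay_Buffer.SimpleExperienceBuffer from the imported
# Replay_Buffer module (see the buffer cell below).
# import collections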

# def initialize_buffer(config):
#     return SimpleExperienceBuffer(config['max_size'], config['batch_size'])

# class SimpleExperienceBuffer:
#     def __init__(self, capacity, batch_size):
#         self.buffer = collections.deque(maxlen=capacity)
#         self.batch_size = batch_size

#     def __len__(self):
#         return len(self.buffer)

#     def append(self, experience):
#         self.buffer.append(experience)

#     def sample(self):
#         indices = np.random.choice(len(self.buffer), self.batch_size)
#         states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in indices])
#         return np.array(states), np.array(actions), np.array(rewards, dtype=np.float32), \
#                np.array(dones, dtype=np.uint8), np.array(next_states)

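# DQN Implementation Example

This notebook wires together the project's `Environment`, `Algorithm` and `Replay_Buffer` modules to train a DQN agent on CartPole-v0 for a few episodes, then inspects one sampled replay batch.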

## Initialize the configuration, environment, algorithm, agent and replay buffer

In [13]:
# Seed the global RNGs of Python, NumPy and PyTorch before the algorithm (and its
# network) is created below, so the run is repeatable insofar as the project modules
# draw from these global generators.
# NOTE(review): the Gym environment's own RNG is not seeded here; confirm whether
# Environment exposes a seed option.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

config = {
    # NOTE(review): 'episodes' is 1, but the training cell below runs its own
    # hard-coded episode count; confirm which one is authoritative.
    'Learner': {'type': 'DQN', 'episodes': 1},
    'Algorithm': {
        'algorithm': 'DQN',
        'replay_buffer': True,
        'learning_rate': 0.003,
        'optimizer': 'Adam',
        'loss_function': 'MSELoss',
        'regularizer': 0,
        'recurrence': 0,
        'gamma': 0.99,
        'beta': 0,
        'epsilon_start': 1,
        'epsilon_end': 0.02,
        'epsilon_decay': 5e-05,
        'c': 10000,  # presumably the target-network sync interval; confirm in Algorithm
    },
    'Environment': {
        'env_type': 'Gym',
        'environment': 'CartPole-v0',
        'action_space': 'discrete',
        # NOTE(review): CartPole observations are 4-element float vectors (see the
        # (64, 4) sampled batch shape at the end), so 'discrete' looks wrong; confirm
        # how Environment uses this flag.
        'observation_space': 'discrete',
        'env_render': False,
        'num_agents': 1,
    },
    'Replay_Buffer': {'max_size': 100000, 'batch_size': 64, 'num_agents': 1},
    'Agent': {'num_agents': 1},
    'Network': {
        'algorithm': 'DQN',
        'layers': 3,
        'hidden_layer1': 'linear',
        'hidden_size1': 50,
        'activation_function1': 'relu',
        'hidden_layer2': 'linear',
        'hidden_size2': 100,
        'activation_function2': 'relu',
        'hidden_layer3': 'linear',
    },
}

In [14]:
# Echo the configuration so the hyperparameters used for this run appear in the output.
config

{'Learner': {'type': 'DQN', 'episodes': 1},
 'Algorithm': {'algorithm': 'DQN',
  'replay_buffer': True,
  'learning_rate': 0.003,
  'optimizer': 'Adam',
  'loss_function': 'MSELoss',
  'regularizer': 0,
  'recurrence': 0,
  'gamma': 0.99,
  'beta': 0,
  'epsilon_start': 1,
  'epsilon_end': 0.02,
  'epsilon_decay': 5e-05,
  'c': 10000},
 'Environment': {'env_type': 'Gym',
  'environment': 'CartPole-v0',
  'action_space': 'discrete',
  'observation_space': 'discrete',
  'env_render': False,
  'num_agents': 1},
 'Replay_Buffer': {'max_size': 100000, 'batch_size': 64, 'num_agents': 1},
 'Agent': {'num_agents': 1},
 'Network': {'algorithm': 'DQN',
  'layers': 3,
  'hidden_layer1': 'linear',
  'hidden_size1': 50,
  'activation_function1': 'relu',
  'hidden_layer2': 'linear',
  'hidden_size2': 100,
  'activation_function2': 'relu',
  'hidden_layer3': 'linear'}}

In [15]:
# Build the Gym environment described by config['Environment'] (CartPole-v0) and show its repr.
env = ENVS.initialize_env(config['Environment'])
env

<Environment.GymEnvironment at 0x7fe03d186cf8>

In [16]:
# Build the algorithm from the environment's observation and action spaces plus the
# Algorithm, Agent and Network config sections (passed together as one list, in that order).
alg = ALGS.initialize_algorithm(
    env.get_observation_space(),
    env.get_action_space(),
    [config['Algorithm'], config['Agent'], config['Network']],
)
alg

<Algorithm.DQAlgorithm at 0x7fe03559b518>

In [17]:
# Ask the algorithm to construct its agent; the repr in the output shows a DQAgent.
agent = alg.create_agent()
agent

<Agent.DQAgent at 0x7fe03559b2e8>

In [18]:
# Create the replay buffer from config['Replay_Buffer'] (max_size / batch_size).
# NOTE(review): the meaning of the three trailing positional None arguments is not
# visible here; pass them by keyword once confirmed against Replay_Buffer.initialize_buffer.
buffer = Replay_Buffer.initialize_buffer(config['Replay_Buffer'], None, None, None)
buffer

<Replay_Buffer.SimpleExperienceBuffer at 0x7fe03559bd30>

## Run a short training loop

In [19]:
num_of_episodes = 50

for i in range(num_of_episodes):
    obs = env.reset()
    obs = env.get_observation()
    done = False
    while not done:
        action = alg.get_action(agent, obs, i)
        next_obs, reward, done = env.step(action)
        
        experience = [obs, action, reward, done, next_obs]
        buffer.append(experience)
        
        experience = buffer.sample()
        alg.update(agent, experience, i)

        obs = next_obs

[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828, 



[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828, 



[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828, 



[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828, 



[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828, 




[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,




[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]
[array([ 0.01092799, -0.04475828,



In [20]:
# Draw one batch from the replay buffer and check its layout. Fresh names are used so
# the training loop's `obs` / `next_obs` are not silently overwritten.
batch_obs, batch_actions, batch_rewards, batch_dones, batch_next_obs = buffer.sample()
expected_batch_size = config['Replay_Buffer']['batch_size']
assert all(
    len(arr) == expected_batch_size
    for arr in (batch_obs, batch_actions, batch_rewards, batch_dones, batch_next_obs)
), 'replay batch size does not match config'
batch_obs.shape, batch_actions.shape, batch_rewards.shape, batch_dones.shape, batch_next_obs.shape

[array([ 0.01092799, -0.04475828,  0.0063452 ,  0.04364473]), 1, 1.0, False, array([ 0.01003283,  0.15027211,  0.00721809, -0.24702949])]


((64, 4), (64,), (64,), (64,), (64, 4))