In [1]:
using PyCall

In [2]:
gym = pyimport("gym")

PyObject <module 'gym' from '/home/d9w/.julia/conda/3/lib/python3.7/site-packages/gym/__init__.py'>

In [3]:
cartpole = gym.make("CartPole-v1")

PyObject <TimeLimit<CartPoleEnv<CartPole-v1>>>

In [4]:
state = cartpole.reset()
action = cartpole.action_space.sample()
next_state, reward, done, _ = cartpole.step(action)

([0.007897900672134911, 0.1712120802856951, -0.04306554378875549, -0.3142409407529665], 1.0, false, Dict{Any,Any}())

In [5]:
typeof(state), typeof(action), typeof(reward), typeof(next_state)

(Array{Float64,1}, Int64, Float64, Array{Float64,1})

In [6]:
replay_buffer_size = Int(1e6)
nb_samples = Int(2e6)
nb_batches = Int(1e4)
batch_size = 50

50

In [7]:
abstract type Buffer end

In [8]:
struct Transition
    state::Array{Float64}
    action::Int64
    reward::Float64
    next_state::Array{Float64}
end

In [9]:
t = Transition(state, action, reward, next_state)

Transition([0.008387860821805662, -0.02449800748353751, -0.04289879348189916, -0.008337515342816536], 1, 1.0, [0.007897900672134911, 0.1712120802856951, -0.04306554378875549, -0.3142409407529665])

In [10]:
tqdm = pyimport("tqdm")

function test_insertion_tqdm(buffer::Buffer, nb_samples::Int, transition::Transition)
    state = cartpole.reset()
    for _ in tqdm.trange(nb_samples)
        append(buffer, transition)
    end
end

function test_sampling_tqdm(buffer::Buffer, nb_batches::Int)
    for _ in tqdm.trange(nb_samples)
        sample(buffer, batch_size)
    end
end

test_sampling_tqdm (generic function with 1 method)

In [11]:
function test_insertion_timev(buffer::Buffer, nb_samples::Int, transition::Transition)
    println("Insertion of ", nb_samples, " samples:")
    @timev for i in 1:nb_samples
        append(buffer, transition)
    end
end

function test_sampling_timev(buffer::Buffer, nb_batches::Int)
    println("Sampling of ", nb_samples, " batches:")
    @timev for i in 1:nb_batches
        sample(buffer, batch_size)
    end
end  

test_sampling_timev (generic function with 1 method)

In [12]:
struct ReplayBuffer1 <: Buffer
    data::Array{Transition}
    capacity::Int64
end

function ReplayBuffer1(capacity::Int64)
    ReplayBuffer1(Array{Transition}(undef, 0), capacity)
end

function append(buffer::ReplayBuffer1, t::Transition)
    if length(buffer.data) < buffer.capacity
        push!(buffer.data, t)
    end    
end

function sample(buffer::ReplayBuffer1, batch_size::Int)
    rand(buffer.data, batch_size)
end

sample (generic function with 1 method)

In [13]:
memory = ReplayBuffer1(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples, t)
test_sampling_tqdm(memory, nb_batches)
# buffer is already full for the next tests?
#memory = ReplayBuffer1(replay_buffer_size)
test_insertion_timev(memory, nb_samples, t)
test_sampling_timev(memory, nb_batches)

Insertion of 2000000 samples:
  0.005091 seconds
elapsed time (ns): 5091034
Sampling of 

  0%|                                              | 0/2000000 [00:00<?, ?it/s]  4%|█▏                            | 75208/2000000 [00:00<00:02, 752078.92it/s]  9%|██▍                          | 170305/2000000 [00:00<00:02, 802425.28it/s] 13%|███▊                         | 260538/2000000 [00:00<00:02, 829992.60it/s] 17%|████▊                        | 333351/2000000 [00:00<00:02, 796561.44it/s] 22%|██████▍                      | 441049/2000000 [00:00<00:01, 864053.61it/s] 27%|███████▊                     | 541883/2000000 [00:00<00:01, 902808.39it/s] 31%|█████████                    | 625380/2000000 [00:00<00:01, 833272.12it/s] 36%|██████████▌                  | 726106/2000000 [00:00<00:01, 878806.49it/s] 42%|████████████▏                | 838411/2000000 [00:00<00:01, 940143.44it/s] 48%|█████████████▊               | 950709/2000000 [00:01<00:01, 988421.67it/s] 53%|██████████████▋             | 1050105/2000000 [00:01<00:01, 878707.85it/s] 58%|████████████████▏           | 1159

2000000 batches:
  0.018513 seconds (10.00 k allocations: 4.730 MiB)
elapsed time (ns): 18512566
bytes allocated:   4960000
pool allocs:       10000
