In [1]:
using PyCall

In [2]:
gym = pyimport("gym")

PyObject <module 'gym' from '/home/d9w/.julia/conda/3/lib/python3.7/site-packages/gym/__init__.py'>

In [3]:
# Instantiate the "CartPole-v1" environment through the Python `gym` module.
# `cartpole` is a global that the following cells call `reset`/`step` on.
cartpole = gym.make("CartPole-v1")

PyObject <TimeLimit<CartPoleEnv<CartPole-v1>>>

In [4]:
# Collect one real transition from the environment to use as benchmark payload:
# reset, draw an action from the action space, and take a single step.
state = cartpole.reset()
action = cartpole.action_space.sample()
# The fourth value returned by `step` is discarded.
next_state, reward, done, _ = cartpole.step(action)

([0.04939737410730551, -0.16508445572993605, 0.03675858373086133, 0.3005931360049499], 1.0, false, Dict{Any,Any}())

In [5]:
# Inspect the Julia types PyCall produced for the values stored in a transition.
typeof(state), typeof(action), typeof(reward), typeof(next_state)

(Array{Float64,1}, Int64, Float64, Array{Float64,1})

In [6]:
# Benchmark sizes. `Int(...)` converts the floating-point literals to integers.
replay_buffer_size = Int(1e6)  # capacity passed to the buffer constructors
nb_samples = Int(2e6)          # number of insertions; larger than the capacity
nb_batches = Int(1e4)          # number of `sample` calls
# Read as a global by the `test_sampling_*` functions.
batch_size = 50

50

In [7]:
# Common supertype of the replay-buffer implementations (ReplayBuffer1, ReplayBuffer2).
abstract type Buffer end
# Common supertype of the stored records (GoodTransition, BadTransition).
abstract type Transition end

In [8]:
# Transition record with type-annotated fields: (state, action, reward, next_state).
# NOTE(review): `Array{Float64}` leaves the number of dimensions unspecified, so the
# two array fields are not concretely typed; `Vector{Float64}` (or a dimension type
# parameter) would be needed for fully concrete fields. TODO confirm whether
# non-1-D states must be supported before tightening.
struct GoodTransition <: Transition
    state::Array{Float64}
    action::Int64
    reward::Float64
    next_state::Array{Float64}
end

In [9]:
# Transition record with loosely typed fields, used by the "untyped transitions"
# benchmark cells for comparison with GoodTransition.
struct BadTransition <: Transition
    state::Array       # element type and dimensionality unspecified
    action             # no annotation: the field type is `Any`
    reward             # no annotation: the field type is `Any`
    next_state::Array  # element type and dimensionality unspecified
end

In [10]:
# Build one typed and one untyped transition from the same environment step.
# These two globals are the payloads inserted repeatedly by the benchmark cells.
t = GoodTransition(state, action, reward, next_state)
tbad = BadTransition(state, action, reward, next_state)

BadTransition([0.04878645848476218, 0.03054578112716652, 0.03682814909862235, -0.0034782683880511325], 0, 1.0, [0.04939737410730551, -0.16508445572993605, 0.03675858373086133, 0.3005931360049499])

In [11]:
tqdm = pyimport("tqdm")

"""
    test_insertion_tqdm(buffer::Buffer, nb_samples::Int, transition::Transition)

Insert `transition` into `buffer` `nb_samples` times with `append`, iterating over
`tqdm.trange(nb_samples)` so that a progress bar is displayed.

Returns `nothing`.
"""
function test_insertion_tqdm(buffer::Buffer, nb_samples::Int, transition::Transition)
    # The same `transition` is re-inserted on every iteration, so the environment is
    # not needed here; no Python call is made besides the tqdm iterator itself.
    for _ in tqdm.trange(nb_samples)
        append(buffer, transition)
    end
    return nothing
end

"""
    test_sampling_tqdm(buffer::Buffer, nb_batches::Int, batch::Int=batch_size)

Draw `nb_batches` batches of `batch` transitions from `buffer` with `sample`,
iterating over `tqdm.trange(nb_batches)` so that a progress bar is displayed.

`batch` defaults to the global `batch_size`. Taking it as an argument keeps the
loop body from reading an untyped global on every iteration.

Returns `nothing`.
"""
function test_sampling_tqdm(buffer::Buffer, nb_batches::Int, batch::Int=batch_size)
    # Iterate over the `nb_batches` argument (not the global `nb_samples`).
    for _ in tqdm.trange(nb_batches)
        sample(buffer, batch)
    end
    return nothing
end

test_sampling_tqdm (generic function with 1 method)

In [12]:
"""
    test_insertion_timev(buffer::Buffer, nb_samples::Int, transition::Transition)

Print a header line, then insert `transition` into `buffer` `nb_samples` times with
`append`, reporting the loop's time and allocation statistics through `@timev`.

Returns `nothing`.
"""
function test_insertion_timev(buffer::Buffer, nb_samples::Int, transition::Transition)
    println("Insertion of ", nb_samples, " samples:")
    @timev begin
        for _ in 1:nb_samples
            append(buffer, transition)
        end
    end
    return nothing
end

"""
    test_sampling_timev(buffer::Buffer, nb_batches::Int, batch::Int=batch_size)

Print a header line, then draw `nb_batches` batches of `batch` transitions from
`buffer` with `sample`, reporting the loop's time and allocation statistics through
`@timev`.

`batch` defaults to the global `batch_size`. Taking it as an argument keeps the
timed loop from reading an untyped global on every iteration.
"""
function test_sampling_timev(buffer::Buffer, nb_batches::Int, batch::Int=batch_size)
    # Report the number of batches actually drawn (the loop bound below).
    println("Sampling of ", nb_batches, " batches:")
    @timev for i in 1:nb_batches
        sample(buffer, batch)
    end
end

test_sampling_timev (generic function with 1 method)

In [13]:
# Replay buffer backed by a growable vector: `append` pushes new transitions at the
# end and removes the first element once `capacity` entries are stored.
struct ReplayBuffer1 <: Buffer
    data::Vector{Transition}  # stored transitions; element type is the abstract `Transition`
    capacity::Int64           # length of `data` at which `append` starts dropping the first entry
end

"""
    ReplayBuffer1(capacity::Int64)

Create a `ReplayBuffer1` with no stored transitions and the given `capacity`.
"""
ReplayBuffer1(capacity::Int64) = ReplayBuffer1(Transition[], capacity)

"""
    append(buffer::ReplayBuffer1, t::Transition)

Store `t` at the end of `buffer`. When the buffer already holds `capacity`
transitions, the first (oldest) one is removed before `t` is pushed.

Returns the underlying data vector (the result of `push!`).
"""
function append(buffer::ReplayBuffer1, t::Transition)
    store = buffer.data
    # Make room first so the length never exceeds the capacity.
    length(store) == buffer.capacity && popfirst!(store)
    return push!(store, t)
end

"""
    sample(buffer::ReplayBuffer1, batch_size::Int)

Return `batch_size` transitions drawn with `rand` from the stored data
(independent draws, so the same transition may appear more than once).
"""
sample(buffer::ReplayBuffer1, batch_size::Int) = rand(buffer.data, batch_size)

sample (generic function with 1 method)

In [14]:
# Replay buffer backed by a vector allocated once at full capacity and written
# circularly: `append` overwrites slot `i` and wraps back to 1 after `capacity`.
mutable struct ReplayBuffer2 <: Buffer
    data::Vector{Transition}  # storage; the constructor allocates it with `capacity` undefined slots
    capacity::Int64           # number of slots
    i::Int64                  # index of the slot the next `append` writes to
    filled::Bool              # set to true the first time `i` wraps around
end

"""
    ReplayBuffer2(capacity::Int64)

Create a `ReplayBuffer2` whose storage is allocated up front with `capacity`
uninitialised slots, with the write index at 1 and the `filled` flag cleared.
"""
function ReplayBuffer2(capacity::Int64)
    slots = Vector{Transition}(undef, capacity)
    return ReplayBuffer2(slots, capacity, 1, false)
end

"""
    append(buffer::ReplayBuffer2, t::Transition)

Write `t` into the slot at the current write index and advance the index. When the
index moves past `capacity`, mark the buffer as filled and wrap the index to 1.

Returns `1` when the index wrapped on this call and `nothing` otherwise.
"""
function append(buffer::ReplayBuffer2, t::Transition)
    buffer.data[buffer.i] = t
    advanced = buffer.i + 1
    buffer.i = advanced
    # Still inside the storage: nothing more to do.
    advanced > buffer.capacity || return nothing
    buffer.filled = true
    return buffer.i = 1
end

"""
    sample(buffer::ReplayBuffer2, batch_size::Int)

Return `batch_size` transitions drawn with `rand` from the written part of `buffer`
(independent draws, so the same transition may appear more than once).

Once the buffer has wrapped (`filled`), every slot is eligible; before that only
slots `1:(i - 1)` have been written and are sampled from.

# Throws
- `ArgumentError` if nothing has been appended yet.
"""
function sample(buffer::ReplayBuffer2, batch_size::Int)
    if buffer.filled
        return rand(buffer.data, batch_size)
    end
    written = buffer.i - 1
    written > 0 || throw(ArgumentError("cannot sample from an empty buffer"))
    # A view restricts sampling to the written prefix without copying it, which the
    # range-indexing form `data[1:written]` would do on every call.
    return rand(view(buffer.data, 1:written), batch_size)
end

sample (generic function with 2 methods)

In [15]:
# Benchmark ReplayBuffer1 with the typed transition `t`: tqdm-driven loops first,
# then the @timev-instrumented loops on the same buffer. Since `nb_samples` exceeds
# `replay_buffer_size`, the buffer is already at capacity when the timed insertion runs.
memory = ReplayBuffer1(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples, t)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timev(memory, nb_samples, t)
test_sampling_timev(memory, nb_batches)

Insertion of 2000000 samples:
  

  0%|                                              | 0/2000000 [00:00<?, ?it/s]  4%|█                             | 74231/2000000 [00:00<00:02, 742305.40it/s]  8%|██▎                          | 161348/2000000 [00:00<00:02, 776774.13it/s] 12%|███▌                         | 249567/2000000 [00:00<00:02, 805653.38it/s] 16%|████▌                        | 315768/2000000 [00:00<00:02, 756414.15it/s] 21%|██████▏                      | 423185/2000000 [00:00<00:01, 830044.48it/s] 26%|███████▍                     | 516138/2000000 [00:00<00:01, 857578.47it/s] 30%|████████▋                    | 595405/2000000 [00:00<00:01, 787386.85it/s] 35%|██████████                   | 694134/2000000 [00:00<00:01, 838307.89it/s] 40%|███████████▌                 | 798765/2000000 [00:00<00:01, 891472.90it/s] 45%|█████████████▏               | 909761/2000000 [00:01<00:01, 947420.30it/s] 50%|██████████████              | 1005381/2000000 [00:01<00:01, 868808.49it/s] 55%|███████████████▌            | 1107

0.033539 seconds (1 allocation: 0 bytes)
elapsed time (ns): 33539214
realloc() calls:   1
Sampling of 2000000 batches:
  0.020159 seconds (10.00 k allocations: 4.730 MiB)
elapsed time (ns): 20159457
bytes allocated:   4960000
pool allocs:       10000


In [None]:
# Same ReplayBuffer1 benchmark with a different ordering: a timed pass on a fresh
# buffer, then the tqdm pass, then a second timed pass on the now-full buffer.
memory = ReplayBuffer1(replay_buffer_size)
test_insertion_timev(memory, nb_samples, t)
test_sampling_timev(memory, nb_batches)
test_insertion_tqdm(memory, nb_samples, t)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timev(memory, nb_samples, t)
test_sampling_timev(memory, nb_batches)

In [16]:
# ReplayBuffer1 benchmark with untyped transitions: identical to the typed run
# except that the payload is `tbad` (a BadTransition) instead of `t`.
memory = ReplayBuffer1(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples, tbad)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timev(memory, nb_samples, tbad)
test_sampling_timev(memory, nb_batches)

Insertion of 2000000 samples:
  0.052615 seconds (1 allocation: 0 bytes)
elapsed time (ns): 52615481
realloc() calls:   1
Sampling of 2000000 batches:
  0.013670 seconds (10.00 k allocations: 4.730 MiB)
elapsed time (ns): 13670004
bytes allocated:   4960000
pool allocs:       10000



  0%|                                              | 0/2000000 [00:00<?, ?it/s]  5%|█▍                            | 96933/2000000 [00:00<00:01, 969298.57it/s] 10%|██▊                          | 192634/2000000 [00:00<00:01, 965578.57it/s] 13%|███▋                         | 256898/2000000 [00:00<00:02, 793940.40it/s] 18%|█████▏                       | 358091/2000000 [00:00<00:01, 848793.69it/s] 23%|██████▌                      | 454360/2000000 [00:00<00:01, 880028.33it/s] 28%|████████                     | 552313/2000000 [00:00<00:01, 907689.18it/s] 32%|█████████▏                   | 634346/2000000 [00:00<00:01, 848770.31it/s] 37%|██████████▊                  | 746401/2000000 [00:00<00:01, 915374.35it/s] 43%|████████████▎                | 852144/2000000 [00:00<00:01, 953812.93it/s] 47%|█████████████▋               | 946372/2000000 [00:01<00:01, 866142.94it/s] 52%|██████████████▋             | 1048799/2000000 [00:01<00:01, 908204.15it/s] 58%|████████████████▏           | 115

In [17]:
# Benchmark ReplayBuffer2 (circular buffer) with the typed transition `t`, using the
# same sequence of tqdm and @timev passes as the ReplayBuffer1 cells.
memory = ReplayBuffer2(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples, t)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timev(memory, nb_samples, t)
test_sampling_timev(memory, nb_batches)

Insertion of 2000000 samples:
  0.003771 seconds
elapsed time (ns): 3770681
Sampling of 2000000 batches:
  0.018892 seconds (10.00 k allocations: 4.730 MiB)
elapsed time (ns): 18892422
bytes allocated:   4960000
pool allocs:       10000



  0%|                                              | 0/2000000 [00:00<?, ?it/s]  4%|█▏                            | 82021/2000000 [00:00<00:02, 820081.74it/s]  7%|█▉                           | 137281/2000000 [00:00<00:02, 716022.85it/s]  9%|██▋                          | 188477/2000000 [00:00<00:02, 639547.08it/s] 15%|████▏                        | 291029/2000000 [00:00<00:02, 720949.23it/s] 19%|█████▌                       | 386595/2000000 [00:00<00:02, 778282.50it/s] 23%|██████▊                      | 469075/2000000 [00:00<00:01, 791676.46it/s] 27%|███████▊                     | 542344/2000000 [00:00<00:02, 611690.94it/s] 32%|█████████▍                   | 648122/2000000 [00:00<00:01, 700288.60it/s] 37%|██████████▊                  | 748274/2000000 [00:00<00:01, 769744.13it/s] 42%|████████████▎                | 847087/2000000 [00:01<00:01, 824400.69it/s] 47%|█████████████▌               | 935017/2000000 [00:01<00:01, 664234.25it/s] 52%|██████████████▋             | 104

In [None]:
# ReplayBuffer2 benchmark with untyped transitions: identical to the typed run
# except that the payload is `tbad` (a BadTransition) instead of `t`.
memory = ReplayBuffer2(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples, tbad)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timev(memory, nb_samples, tbad)
test_sampling_timev(memory, nb_batches)