Skip to content

Commit

Permalink
Implement Algorithm L for Reservoir Sampling in Enum
Browse files Browse the repository at this point in the history
This optimizes Enum.random/1 and Enum.take_random/2
to be 6.3x times faster and use 2.7x less memory.
  • Loading branch information
josevalim committed Nov 3, 2023
1 parent a34cd28 commit 8e9cbfc
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 96 deletions.
159 changes: 84 additions & 75 deletions lib/elixir/lib/enum.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2362,12 +2362,6 @@ defmodule Enum do
the random value. Check its documentation for setting a
different random algorithm or a different seed.
The implementation is based on the
[reservoir sampling](https://en.wikipedia.org/wiki/Reservoir_sampling#Relation_to_Fisher-Yates_shuffle)
algorithm.
It assumes that the sample being returned can fit into memory;
the input `enumerable` doesn't have to, as it is traversed just once.
If a range is passed into the function, this function will pick a
random value between the range limits, without traversing the whole
range (thus executing in constant time and constant memory).
Expand All @@ -2386,6 +2380,12 @@ defmodule Enum do
iex> Enum.random(1..1_000)
309
## Implementation
The random functions in this module implement reservoir sampling,
which allows them to sample infinite collections. In particular,
we implement Algorithm L, as described in by Kim-Hung Li in
"Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".
"""
@spec random(t) :: element
def random(enumerable)
Expand Down Expand Up @@ -2902,11 +2902,11 @@ defmodule Enum do
the default from Erlang/OTP 22:
# Although not necessary, let's seed the random algorithm
iex> :rand.seed(:exsss, {1, 2, 3})
iex> Enum.shuffle([1, 2, 3])
[3, 2, 1]
iex> :rand.seed(:exsss, {11, 22, 33})
iex> Enum.shuffle([1, 2, 3])
[2, 1, 3]
iex> Enum.shuffle([1, 2, 3])
[2, 3, 1]
"""
@spec shuffle(t) :: list
Expand All @@ -2916,9 +2916,12 @@ defmodule Enum do
[{:rand.uniform(), x} | acc]
end)

shuffle_unwrap(:lists.keysort(1, randomized), [])
shuffle_unwrap(:lists.keysort(1, randomized))
end

defp shuffle_unwrap([{_, h} | rest]), do: [h | shuffle_unwrap(rest)]
defp shuffle_unwrap([]), do: []

@doc """
Returns a subset list of the given `enumerable` by `index_range`.
Expand Down Expand Up @@ -3588,100 +3591,114 @@ defmodule Enum do
# Although not necessary, let's seed the random algorithm
iex> :rand.seed(:exsss, {1, 2, 3})
iex> Enum.take_random(1..10, 2)
[3, 1]
[6, 1]
iex> Enum.take_random(?a..?z, 5)
~c"mikel"
~c"bkzmt"
"""
@spec take_random(t, non_neg_integer) :: list
def take_random(enumerable, count)
def take_random(_enumerable, 0), do: []

def take_random([], _), do: []
def take_random([h | t], 1), do: take_random_list_one(t, h, 1)

def take_random(enumerable, 1) do
enumerable
|> reduce([], fn
x, [current | index] ->
if :rand.uniform(index + 1) == 1 do
[x | index + 1]
else
[current | index + 1]
end
|> reduce({0, 0, 1.0, nil}, fn
elem, {idx, idx, w, _current} ->
{jdx, w} = take_jdx_w(idx, w, 1)
{idx + 1, jdx, w, elem}

x, [] ->
[x | 1]
_elem, {idx, jdx, w, current} ->
{idx + 1, jdx, w, current}
end)
|> case do
[] -> []
[current | _index] -> [current]
{0, 0, 1.0, nil} -> []
{_idx, _jdx, _w, current} -> [current]
end
end

def take_random(enumerable, count) when is_integer(count) and count in 0..128 do
def take_random(enumerable, count) when count in 0..128 do
sample = Tuple.duplicate(nil, count)

reducer = fn elem, {idx, sample} ->
jdx = random_index(idx)
reducer = fn
elem, {idx, jdx, w, sample} when idx < count ->
rand = take_index(idx)
sample = sample |> put_elem(idx, elem(sample, rand)) |> put_elem(rand, elem)

cond do
idx < count ->
value = elem(sample, jdx)
{idx + 1, put_elem(sample, idx, value) |> put_elem(jdx, elem)}
if idx == jdx do
{jdx, w} = take_jdx_w(idx, w, count)
{idx + 1, jdx, w, sample}
else
{idx + 1, jdx, w, sample}
end

jdx < count ->
{idx + 1, put_elem(sample, jdx, elem)}
elem, {idx, idx, w, sample} ->
pos = :rand.uniform(count) - 1
{jdx, w} = take_jdx_w(idx, w, count)
{idx + 1, jdx, w, put_elem(sample, pos, elem)}

true ->
{idx + 1, sample}
end
_elem, {idx, jdx, w, sample} ->
{idx + 1, jdx, w, sample}
end

{size, sample} = reduce(enumerable, {0, sample}, reducer)
sample |> Tuple.to_list() |> take(Kernel.min(count, size))
{size, _, _, sample} = reduce(enumerable, {0, count - 1, 1.0, sample}, reducer)

if count < size do
Tuple.to_list(sample)
else
take_tupled(sample, size, [])
end
end

def take_random(enumerable, count) when is_integer(count) and count >= 0 do
reducer = fn elem, {idx, sample} ->
jdx = random_index(idx)

cond do
idx < count ->
value = Map.get(sample, jdx)
{idx + 1, Map.put(sample, idx, value) |> Map.put(jdx, elem)}
reducer = fn
elem, {idx, jdx, w, sample} when idx < count ->
rand = take_index(idx)
sample = sample |> Map.put(idx, Map.get(sample, rand)) |> Map.put(rand, elem)

if idx == jdx do
{jdx, w} = take_jdx_w(idx, w, count)
{idx + 1, jdx, w, sample}
else
{idx + 1, jdx, w, sample}
end

jdx < count ->
{idx + 1, Map.put(sample, jdx, elem)}
elem, {idx, idx, w, sample} ->
pos = :rand.uniform(count) - 1
{jdx, w} = take_jdx_w(idx, w, count)
{idx + 1, jdx, w, %{sample | pos => elem}}

true ->
{idx + 1, sample}
end
_elem, {idx, jdx, w, sample} ->
{idx + 1, jdx, w, sample}
end

{size, sample} = reduce(enumerable, {0, %{}}, reducer)
take_random(sample, Kernel.min(count, size), [])
{size, _, _, sample} = reduce(enumerable, {0, count - 1, 1.0, %{}}, reducer)
take_mapped(sample, Kernel.min(count, size), [])
end

defp take_random(_sample, 0, acc), do: acc

defp take_random(sample, position, acc) do
position = position - 1
take_random(sample, position, [Map.get(sample, position) | acc])
@compile {:inline, take_jdx_w: 3, take_index: 1}
defp take_jdx_w(idx, w, count) do
w = w * :math.exp(:math.log(:rand.uniform()) / count)
jdx = idx + floor(:math.log(:rand.uniform()) / :math.log(1 - w)) + 1
{jdx, w}
end

defp take_random_list_one([h | t], current, index) do
if :rand.uniform(index + 1) == 1 do
take_random_list_one(t, h, index + 1)
else
take_random_list_one(t, current, index + 1)
end
defp take_index(0), do: 0
defp take_index(idx), do: :rand.uniform(idx + 1) - 1

defp take_tupled(_sample, 0, acc), do: acc

defp take_tupled(sample, position, acc) do
position = position - 1
take_tupled(sample, position, [elem(sample, position) | acc])
end

defp take_random_list_one([], current, _), do: [current]
defp take_mapped(_sample, 0, acc), do: acc

defp random_index(0), do: 0
defp random_index(idx), do: :rand.uniform(idx + 1) - 1
defp take_mapped(sample, position, acc) do
position = position - 1
take_mapped(sample, position, [Map.fetch!(sample, position) | acc])
end

@doc """
Takes the elements from the beginning of the `enumerable` while `fun` returns
Expand Down Expand Up @@ -4439,14 +4456,6 @@ defmodule Enum do
[acc | scan_list(rest, acc, fun)]
end

## shuffle

defp shuffle_unwrap([{_, h} | enumerable], t) do
shuffle_unwrap(enumerable, [h | t])
end

defp shuffle_unwrap([], t), do: t

## slice

defp slice_forward(enumerable, start, amount, step) when start < 0 do
Expand Down
39 changes: 18 additions & 21 deletions lib/elixir/test/elixir/enum_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,7 @@ defmodule EnumTest do
test "shuffle/1" do
# set a fixed seed so the test can be deterministic
:rand.seed(:exsss, {1374, 347_975, 449_264})
assert Enum.shuffle([1, 2, 3, 4, 5]) == [1, 3, 4, 5, 2]
assert Enum.shuffle([1, 2, 3, 4, 5]) == [2, 5, 4, 3, 1]
end

test "slice/2" do
Expand Down Expand Up @@ -1377,16 +1377,16 @@ defmodule EnumTest do
seed1 = {1406, 407_414, 139_258}
seed2 = {1406, 421_106, 567_597}
:rand.seed(:exsss, seed1)
assert Enum.take_random([1, 2, 3], 1) == [3]
assert Enum.take_random([1, 2, 3], 2) == [3, 2]
assert Enum.take_random([1, 2, 3], 1) == [2]
assert Enum.take_random([1, 2, 3], 2) == [2, 3]
assert Enum.take_random([1, 2, 3], 3) == [3, 1, 2]
assert Enum.take_random([1, 2, 3], 4) == [1, 3, 2]
assert Enum.take_random([1, 2, 3], 4) == [2, 3, 1]
:rand.seed(:exsss, seed2)
assert Enum.take_random([1, 2, 3], 1) == [1]
assert Enum.take_random([1, 2, 3], 2) == [3, 1]
assert Enum.take_random([1, 2, 3], 3) == [3, 1, 2]
assert Enum.take_random([1, 2, 3], 4) == [2, 1, 3]
assert Enum.take_random([1, 2, 3], 129) == [2, 3, 1]
assert Enum.take_random([1, 2, 3], 3) == [2, 3, 1]
assert Enum.take_random([1, 2, 3], 4) == [3, 2, 1]
assert Enum.take_random([1, 2, 3], 129) == [2, 1, 3]

# assert that every item in the sample comes from the input list
list = for _ <- 1..100, do: make_ref()
Expand Down Expand Up @@ -2071,8 +2071,8 @@ defmodule EnumTest.Range do
test "shuffle/1" do
# set a fixed seed so the test can be deterministic
:rand.seed(:exsss, {1374, 347_975, 449_264})
assert Enum.shuffle(1..5) == [1, 3, 4, 5, 2]
assert Enum.shuffle(1..10//2) == [3, 9, 7, 1, 5]
assert Enum.shuffle(1..5) == [2, 5, 4, 3, 1]
assert Enum.shuffle(1..10//2) == [5, 1, 7, 9, 3]
end

test "slice/2" do
Expand Down Expand Up @@ -2316,30 +2316,27 @@ defmodule EnumTest.Range do
seed1 = {1406, 407_414, 139_258}
seed2 = {1406, 421_106, 567_597}
:rand.seed(:exsss, seed1)
assert Enum.take_random(1..3, 1) == [3]
assert Enum.take_random(1..3, 1) == [2]
:rand.seed(:exsss, seed1)
assert Enum.take_random(1..3, 2) == [3, 1]
:rand.seed(:exsss, seed1)
assert Enum.take_random(1..3, 3) == [3, 1, 2]
:rand.seed(:exsss, seed1)
assert Enum.take_random(1..3, 4) == [3, 1, 2]
:rand.seed(:exsss, seed1)
assert Enum.take_random(3..1//-1, 1) == [1]
assert Enum.take_random(1..3, 5) == [3, 1, 2]
:rand.seed(:exsss, seed1)
assert Enum.take_random(3..1//-1, 1) == [2]
:rand.seed(:exsss, seed2)
assert Enum.take_random(1..3, 1) == [1]
:rand.seed(:exsss, seed2)
assert Enum.take_random(1..3, 2) == [1, 3]
assert Enum.take_random(1..3, 2) == [3, 2]
:rand.seed(:exsss, seed2)
assert Enum.take_random(1..3, 3) == [1, 3, 2]
:rand.seed(:exsss, seed2)
assert Enum.take_random(1..3, 4) == [1, 3, 2]

# make sure optimizations don't change fixed seeded tests
:rand.seed(:exsss, {101, 102, 103})
one = Enum.take_random(1..100, 1)
:rand.seed(:exsss, {101, 102, 103})
two = Enum.take_random(1..100, 2)
assert hd(one) == hd(two)
:rand.seed(:exsss, seed2)
assert Enum.take_random(1..3, 5) == [1, 3, 2]
end

test "take_while/2" do
Expand Down Expand Up @@ -2425,7 +2422,7 @@ defmodule EnumTest.Map do
seed1 = {1406, 407_414, 139_258}
seed2 = {1406, 421_106, 567_597}
:rand.seed(:exsss, seed1)
assert Enum.take_random(map, 1) == [x3]
assert Enum.take_random(map, 1) == [x2]
:rand.seed(:exsss, seed1)
assert Enum.take_random(map, 2) == [x3, x1]
:rand.seed(:exsss, seed1)
Expand All @@ -2435,7 +2432,7 @@ defmodule EnumTest.Map do
:rand.seed(:exsss, seed2)
assert Enum.take_random(map, 1) == [x1]
:rand.seed(:exsss, seed2)
assert Enum.take_random(map, 2) == [x1, x3]
assert Enum.take_random(map, 2) == [x3, x2]
:rand.seed(:exsss, seed2)
assert Enum.take_random(map, 3) == [x1, x3, x2]
:rand.seed(:exsss, seed2)
Expand Down

0 comments on commit 8e9cbfc

Please sign in to comment.