From 8e9cbfcd8c219f9d3558158f1ebee5ec4fadd762 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Fri, 3 Nov 2023 21:19:16 +0100 Subject: [PATCH] Implement Algorithm L for Reservoir Sampling in Enum This optimizes Enum.random/1 and Enum.take_random/2 to be 6.3x faster and use 2.7x less memory. --- lib/elixir/lib/enum.ex | 159 ++++++++++++++------------- lib/elixir/test/elixir/enum_test.exs | 39 +++---- 2 files changed, 102 insertions(+), 96 deletions(-) diff --git a/lib/elixir/lib/enum.ex b/lib/elixir/lib/enum.ex index 9ded32a9224..a2efb8dffac 100644 --- a/lib/elixir/lib/enum.ex +++ b/lib/elixir/lib/enum.ex @@ -2362,12 +2362,6 @@ defmodule Enum do the random value. Check its documentation for setting a different random algorithm or a different seed. - The implementation is based on the - [reservoir sampling](https://en.wikipedia.org/wiki/Reservoir_sampling#Relation_to_Fisher-Yates_shuffle) - algorithm. - It assumes that the sample being returned can fit into memory; - the input `enumerable` doesn't have to, as it is traversed just once. - If a range is passed into the function, this function will pick a random value between the range limits, without traversing the whole range (thus executing in constant time and constant memory). @@ -2386,6 +2380,12 @@ defmodule Enum do iex> Enum.random(1..1_000) 309 + ## Implementation + + The random functions in this module implement reservoir sampling, + which allows them to sample infinite collections. In particular, + we implement Algorithm L, as described by Kim-Hung Li in + "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))". 
""" @spec random(t) :: element def random(enumerable) @@ -2902,11 +2902,11 @@ defmodule Enum do the default from Erlang/OTP 22: # Although not necessary, let's seed the random algorithm - iex> :rand.seed(:exsss, {1, 2, 3}) - iex> Enum.shuffle([1, 2, 3]) - [3, 2, 1] + iex> :rand.seed(:exsss, {11, 22, 33}) iex> Enum.shuffle([1, 2, 3]) [2, 1, 3] + iex> Enum.shuffle([1, 2, 3]) + [2, 3, 1] """ @spec shuffle(t) :: list @@ -2916,9 +2916,12 @@ defmodule Enum do [{:rand.uniform(), x} | acc] end) - shuffle_unwrap(:lists.keysort(1, randomized), []) + shuffle_unwrap(:lists.keysort(1, randomized)) end + defp shuffle_unwrap([{_, h} | rest]), do: [h | shuffle_unwrap(rest)] + defp shuffle_unwrap([]), do: [] + @doc """ Returns a subset list of the given `enumerable` by `index_range`. @@ -3588,100 +3591,114 @@ defmodule Enum do # Although not necessary, let's seed the random algorithm iex> :rand.seed(:exsss, {1, 2, 3}) iex> Enum.take_random(1..10, 2) - [3, 1] + [6, 1] iex> Enum.take_random(?a..?z, 5) - ~c"mikel" + ~c"bkzmt" """ @spec take_random(t, non_neg_integer) :: list def take_random(enumerable, count) def take_random(_enumerable, 0), do: [] - def take_random([], _), do: [] - def take_random([h | t], 1), do: take_random_list_one(t, h, 1) def take_random(enumerable, 1) do enumerable - |> reduce([], fn - x, [current | index] -> - if :rand.uniform(index + 1) == 1 do - [x | index + 1] - else - [current | index + 1] - end + |> reduce({0, 0, 1.0, nil}, fn + elem, {idx, idx, w, _current} -> + {jdx, w} = take_jdx_w(idx, w, 1) + {idx + 1, jdx, w, elem} - x, [] -> - [x | 1] + _elem, {idx, jdx, w, current} -> + {idx + 1, jdx, w, current} end) |> case do - [] -> [] - [current | _index] -> [current] + {0, 0, 1.0, nil} -> [] + {_idx, _jdx, _w, current} -> [current] end end - def take_random(enumerable, count) when is_integer(count) and count in 0..128 do + def take_random(enumerable, count) when count in 0..128 do sample = Tuple.duplicate(nil, count) - reducer = fn elem, {idx, sample} -> - 
jdx = random_index(idx) + reducer = fn + elem, {idx, jdx, w, sample} when idx < count -> + rand = take_index(idx) + sample = sample |> put_elem(idx, elem(sample, rand)) |> put_elem(rand, elem) - cond do - idx < count -> - value = elem(sample, jdx) - {idx + 1, put_elem(sample, idx, value) |> put_elem(jdx, elem)} + if idx == jdx do + {jdx, w} = take_jdx_w(idx, w, count) + {idx + 1, jdx, w, sample} + else + {idx + 1, jdx, w, sample} + end - jdx < count -> - {idx + 1, put_elem(sample, jdx, elem)} + elem, {idx, idx, w, sample} -> + pos = :rand.uniform(count) - 1 + {jdx, w} = take_jdx_w(idx, w, count) + {idx + 1, jdx, w, put_elem(sample, pos, elem)} - true -> - {idx + 1, sample} - end + _elem, {idx, jdx, w, sample} -> + {idx + 1, jdx, w, sample} end - {size, sample} = reduce(enumerable, {0, sample}, reducer) - sample |> Tuple.to_list() |> take(Kernel.min(count, size)) + {size, _, _, sample} = reduce(enumerable, {0, count - 1, 1.0, sample}, reducer) + + if count < size do + Tuple.to_list(sample) + else + take_tupled(sample, size, []) + end end def take_random(enumerable, count) when is_integer(count) and count >= 0 do - reducer = fn elem, {idx, sample} -> - jdx = random_index(idx) - - cond do - idx < count -> - value = Map.get(sample, jdx) - {idx + 1, Map.put(sample, idx, value) |> Map.put(jdx, elem)} + reducer = fn + elem, {idx, jdx, w, sample} when idx < count -> + rand = take_index(idx) + sample = sample |> Map.put(idx, Map.get(sample, rand)) |> Map.put(rand, elem) + + if idx == jdx do + {jdx, w} = take_jdx_w(idx, w, count) + {idx + 1, jdx, w, sample} + else + {idx + 1, jdx, w, sample} + end - jdx < count -> - {idx + 1, Map.put(sample, jdx, elem)} + elem, {idx, idx, w, sample} -> + pos = :rand.uniform(count) - 1 + {jdx, w} = take_jdx_w(idx, w, count) + {idx + 1, jdx, w, %{sample | pos => elem}} - true -> - {idx + 1, sample} - end + _elem, {idx, jdx, w, sample} -> + {idx + 1, jdx, w, sample} end - {size, sample} = reduce(enumerable, {0, %{}}, reducer) - 
take_random(sample, Kernel.min(count, size), []) + {size, _, _, sample} = reduce(enumerable, {0, count - 1, 1.0, %{}}, reducer) + take_mapped(sample, Kernel.min(count, size), []) end - defp take_random(_sample, 0, acc), do: acc - - defp take_random(sample, position, acc) do - position = position - 1 - take_random(sample, position, [Map.get(sample, position) | acc]) + @compile {:inline, take_jdx_w: 3, take_index: 1} + defp take_jdx_w(idx, w, count) do + w = w * :math.exp(:math.log(:rand.uniform()) / count) + jdx = idx + floor(:math.log(:rand.uniform()) / :math.log(1 - w)) + 1 + {jdx, w} end - defp take_random_list_one([h | t], current, index) do - if :rand.uniform(index + 1) == 1 do - take_random_list_one(t, h, index + 1) - else - take_random_list_one(t, current, index + 1) - end + defp take_index(0), do: 0 + defp take_index(idx), do: :rand.uniform(idx + 1) - 1 + + defp take_tupled(_sample, 0, acc), do: acc + + defp take_tupled(sample, position, acc) do + position = position - 1 + take_tupled(sample, position, [elem(sample, position) | acc]) end - defp take_random_list_one([], current, _), do: [current] + defp take_mapped(_sample, 0, acc), do: acc - defp random_index(0), do: 0 - defp random_index(idx), do: :rand.uniform(idx + 1) - 1 + defp take_mapped(sample, position, acc) do + position = position - 1 + take_mapped(sample, position, [Map.fetch!(sample, position) | acc]) + end @doc """ Takes the elements from the beginning of the `enumerable` while `fun` returns @@ -4439,14 +4456,6 @@ defmodule Enum do [acc | scan_list(rest, acc, fun)] end - ## shuffle - - defp shuffle_unwrap([{_, h} | enumerable], t) do - shuffle_unwrap(enumerable, [h | t]) - end - - defp shuffle_unwrap([], t), do: t - ## slice defp slice_forward(enumerable, start, amount, step) when start < 0 do diff --git a/lib/elixir/test/elixir/enum_test.exs b/lib/elixir/test/elixir/enum_test.exs index ebb4fd31889..eee7d772cdc 100644 --- a/lib/elixir/test/elixir/enum_test.exs +++ 
b/lib/elixir/test/elixir/enum_test.exs @@ -1008,7 +1008,7 @@ defmodule EnumTest do test "shuffle/1" do # set a fixed seed so the test can be deterministic :rand.seed(:exsss, {1374, 347_975, 449_264}) - assert Enum.shuffle([1, 2, 3, 4, 5]) == [1, 3, 4, 5, 2] + assert Enum.shuffle([1, 2, 3, 4, 5]) == [2, 5, 4, 3, 1] end test "slice/2" do @@ -1377,16 +1377,16 @@ defmodule EnumTest do seed1 = {1406, 407_414, 139_258} seed2 = {1406, 421_106, 567_597} :rand.seed(:exsss, seed1) - assert Enum.take_random([1, 2, 3], 1) == [3] - assert Enum.take_random([1, 2, 3], 2) == [3, 2] + assert Enum.take_random([1, 2, 3], 1) == [2] + assert Enum.take_random([1, 2, 3], 2) == [2, 3] assert Enum.take_random([1, 2, 3], 3) == [3, 1, 2] - assert Enum.take_random([1, 2, 3], 4) == [1, 3, 2] + assert Enum.take_random([1, 2, 3], 4) == [2, 3, 1] :rand.seed(:exsss, seed2) assert Enum.take_random([1, 2, 3], 1) == [1] assert Enum.take_random([1, 2, 3], 2) == [3, 1] - assert Enum.take_random([1, 2, 3], 3) == [3, 1, 2] - assert Enum.take_random([1, 2, 3], 4) == [2, 1, 3] - assert Enum.take_random([1, 2, 3], 129) == [2, 3, 1] + assert Enum.take_random([1, 2, 3], 3) == [2, 3, 1] + assert Enum.take_random([1, 2, 3], 4) == [3, 2, 1] + assert Enum.take_random([1, 2, 3], 129) == [2, 1, 3] # assert that every item in the sample comes from the input list list = for _ <- 1..100, do: make_ref() @@ -2071,8 +2071,8 @@ defmodule EnumTest.Range do test "shuffle/1" do # set a fixed seed so the test can be deterministic :rand.seed(:exsss, {1374, 347_975, 449_264}) - assert Enum.shuffle(1..5) == [1, 3, 4, 5, 2] - assert Enum.shuffle(1..10//2) == [3, 9, 7, 1, 5] + assert Enum.shuffle(1..5) == [2, 5, 4, 3, 1] + assert Enum.shuffle(1..10//2) == [5, 1, 7, 9, 3] end test "slice/2" do @@ -2316,7 +2316,7 @@ defmodule EnumTest.Range do seed1 = {1406, 407_414, 139_258} seed2 = {1406, 421_106, 567_597} :rand.seed(:exsss, seed1) - assert Enum.take_random(1..3, 1) == [3] + assert Enum.take_random(1..3, 1) == [2] 
:rand.seed(:exsss, seed1) assert Enum.take_random(1..3, 2) == [3, 1] :rand.seed(:exsss, seed1) @@ -2324,22 +2324,19 @@ defmodule EnumTest.Range do :rand.seed(:exsss, seed1) assert Enum.take_random(1..3, 4) == [3, 1, 2] :rand.seed(:exsss, seed1) - assert Enum.take_random(3..1//-1, 1) == [1] + assert Enum.take_random(1..3, 5) == [3, 1, 2] + :rand.seed(:exsss, seed1) + assert Enum.take_random(3..1//-1, 1) == [2] :rand.seed(:exsss, seed2) assert Enum.take_random(1..3, 1) == [1] :rand.seed(:exsss, seed2) - assert Enum.take_random(1..3, 2) == [1, 3] + assert Enum.take_random(1..3, 2) == [3, 2] :rand.seed(:exsss, seed2) assert Enum.take_random(1..3, 3) == [1, 3, 2] :rand.seed(:exsss, seed2) assert Enum.take_random(1..3, 4) == [1, 3, 2] - - # make sure optimizations don't change fixed seeded tests - :rand.seed(:exsss, {101, 102, 103}) - one = Enum.take_random(1..100, 1) - :rand.seed(:exsss, {101, 102, 103}) - two = Enum.take_random(1..100, 2) - assert hd(one) == hd(two) + :rand.seed(:exsss, seed2) + assert Enum.take_random(1..3, 5) == [1, 3, 2] end test "take_while/2" do @@ -2425,7 +2422,7 @@ defmodule EnumTest.Map do seed1 = {1406, 407_414, 139_258} seed2 = {1406, 421_106, 567_597} :rand.seed(:exsss, seed1) - assert Enum.take_random(map, 1) == [x3] + assert Enum.take_random(map, 1) == [x2] :rand.seed(:exsss, seed1) assert Enum.take_random(map, 2) == [x3, x1] :rand.seed(:exsss, seed1) @@ -2435,7 +2432,7 @@ defmodule EnumTest.Map do :rand.seed(:exsss, seed2) assert Enum.take_random(map, 1) == [x1] :rand.seed(:exsss, seed2) - assert Enum.take_random(map, 2) == [x1, x3] + assert Enum.take_random(map, 2) == [x3, x2] :rand.seed(:exsss, seed2) assert Enum.take_random(map, 3) == [x1, x3, x2] :rand.seed(:exsss, seed2)