Skip to content

Commit

Permalink
Fix factory selection when Faker has been seeded (#1630)
Browse files Browse the repository at this point in the history
  • Loading branch information
fcurella committed Mar 23, 2022
1 parent 5b6ce48 commit 5128ae6
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 46 deletions.
5 changes: 5 additions & 0 deletions faker/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@
mod_random = random # compat with name released in 0.8


Sentinel = object()


class Generator:

__config: Dict[str, Dict[Hashable, Any]] = {
"arguments": {},
}

_is_seeded = False
_global_seed = Sentinel

def __init__(self, **config: Dict) -> None:
self.providers: List["BaseProvider"] = []
Expand Down Expand Up @@ -74,6 +78,7 @@ def seed_instance(self, seed: Optional[Hashable] = None) -> "Generator":
@classmethod
def seed(cls, seed: Optional[Hashable] = None) -> None:
random.seed(seed)
cls._global_seed = seed
cls._is_seeded = True

def format(self, formatter: str, *args: Any, **kwargs: Any) -> str:
Expand Down
16 changes: 12 additions & 4 deletions faker/proxy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import copy
import functools
import random
import re

from collections import OrderedDict
Expand All @@ -10,7 +9,7 @@
from .config import DEFAULT_LOCALE
from .exceptions import UniquenessException
from .factory import Factory
from .generator import Generator
from .generator import Generator, Sentinel, random
from .utils.distribution import choices_distribution

_UNIQUE_ATTEMPTS = 1000
Expand Down Expand Up @@ -146,18 +145,27 @@ def _select_factory(self, method_name: str) -> Factory:
"""

factories, weights = self._map_provider_method(method_name)

if len(factories) == 0:
msg = f"No generator object has attribute {method_name!r}"
raise AttributeError(msg)
elif len(factories) == 1:
return factories[0]

if Generator._global_seed is not Sentinel:
random.seed(Generator._global_seed)
if weights:
factory = choices_distribution(factories, weights, length=1)[0]
factory = self._select_factory_distribution(factories, weights)
else:
factory = random.choice(factories)
factory = self._select_factory_choice(factories)
return factory

def _select_factory_distribution(self, factories, weights):
return choices_distribution(factories, weights, random, length=1)[0]

def _select_factory_choice(self, factories):
return random.choice(factories)

def _map_provider_method(self, method_name: str) -> Tuple[List[Factory], Optional[List[float]]]:
"""
Creates a 2-tuple of factories and weights for the given provider method name
Expand Down
107 changes: 65 additions & 42 deletions tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ def test_seed_classmethod(self):
Faker.seed(0)
mock_seed.assert_called_once_with(0)

def test_seed_class_locales(self):
Faker.seed(2043)
fake = Faker(["en_GB", "fr_FR", "en_IN"])
name = fake.name()

for _ in range(5):
assert fake.name() == name

def test_seed_instance(self):
locale = ["de_DE", "en-US", "en-PH", "ja_JP"]
fake = Faker(locale)
Expand Down Expand Up @@ -232,51 +240,48 @@ def test_multiple_locale_caching_behavior(self):
# So each call to name() accesses the cached mapping twice
assert mock_cached_map.call_count == 200

@patch("faker.proxy.random.choice")
@patch("faker.proxy.choices_distribution")
def test_multiple_locale_factory_selection_no_weights(self, mock_choices_fn, mock_random_choice):
@patch("faker.proxy.Faker._select_factory_choice")
@patch("faker.proxy.Faker._select_factory_distribution")
def test_multiple_locale_factory_selection_no_weights(self, mock_factory_distribution, mock_factory_choice):
fake = Faker(["de_DE", "en-US", "en-PH", "ja_JP"])

# There are no distribution weights, so factory selection logic will use `random.choice`
# if multiple factories have the specified provider method
with patch("faker.proxy.Faker._select_factory", wraps=fake._select_factory) as mock_select_factory:
mock_select_factory.assert_not_called()
mock_choices_fn.assert_not_called()
mock_random_choice.assert_not_called()
mock_factory_distribution.assert_not_called()
mock_factory_choice.assert_not_called()

# All factories for the listed locales have the `name` provider method
fake.name()
mock_select_factory.assert_called_once_with("name")
mock_choices_fn.assert_not_called()
mock_random_choice.assert_called_once_with(fake.factories)
mock_factory_distribution.assert_not_called()
mock_factory_choice.assert_called_once_with(fake.factories)
mock_select_factory.reset_mock()
mock_choices_fn.reset_mock()
mock_random_choice.reset_mock()
mock_factory_distribution.reset_mock()
mock_factory_choice.reset_mock()

# Only `en_PH` factory has provider method `luzon_province`, so there is no
# need for `random.choice` factory selection logic to run
fake.luzon_province()
mock_select_factory.assert_called_with("luzon_province")
mock_choices_fn.assert_not_called()
mock_random_choice.assert_not_called()
mock_factory_distribution.assert_not_called()
mock_factory_choice.assert_not_called()
mock_select_factory.reset_mock()
mock_choices_fn.reset_mock()
mock_random_choice.reset_mock()
mock_factory_distribution.reset_mock()
mock_factory_choice.reset_mock()

# Both `en_US` and `ja_JP` factories have provider method `zipcode`
fake.zipcode()
mock_select_factory.assert_called_once_with("zipcode")
mock_choices_fn.assert_not_called()
mock_random_choice.assert_called_once_with(
mock_factory_distribution.assert_not_called()
mock_factory_choice.assert_called_once_with(
[fake["en_US"], fake["ja_JP"]],
)
mock_select_factory.reset_mock()
mock_choices_fn.reset_mock()
mock_random_choice.reset_mock()

@patch("faker.proxy.random.choice")
@patch("faker.proxy.choices_distribution")
def test_multiple_locale_factory_selection_with_weights(self, mock_choices_fn, mock_random_choice):
@patch("faker.proxy.Faker._select_factory_choice")
@patch("faker.proxy.Faker._select_factory_distribution")
def test_multiple_locale_factory_selection_with_weights(self, mock_factory_distribution, mock_factory_choice):
locale = OrderedDict(
[
("de_DE", 3),
Expand All @@ -286,8 +291,8 @@ def test_multiple_locale_factory_selection_with_weights(self, mock_choices_fn, m
]
)
fake = Faker(locale)
mock_choices_fn.assert_not_called()
mock_random_choice.assert_not_called()
mock_factory_distribution.assert_not_called()
mock_factory_choice.assert_not_called()

# Distribution weights have been specified, so factory selection logic will use
# `choices_distribution` if multiple factories have the specified provider method
Expand All @@ -296,34 +301,52 @@ def test_multiple_locale_factory_selection_with_weights(self, mock_choices_fn, m
# All factories for the listed locales have the `name` provider method
fake.name()
mock_select_factory.assert_called_once_with("name")
mock_choices_fn.assert_called_once_with(fake.factories, fake.weights, length=1)
mock_random_choice.assert_not_called()
mock_select_factory.reset_mock()
mock_choices_fn.reset_mock()
mock_random_choice.reset_mock()
mock_factory_distribution.assert_called_once_with(fake.factories, fake.weights)
mock_factory_choice.assert_not_called()

@patch("faker.proxy.Faker._select_factory_choice")
@patch("faker.proxy.Faker._select_factory_distribution")
def test_multiple_locale_factory_selection_single_provider(self, mock_factory_distribution, mock_factory_choice):
locale = OrderedDict(
[
("de_DE", 3),
("en-US", 2),
("en-PH", 1),
("ja_JP", 5),
]
)
fake = Faker(locale)

# Distribution weights have been specified, so factory selection logic will use
# `choices_distribution` if multiple factories have the specified provider method
with patch("faker.proxy.Faker._select_factory", wraps=fake._select_factory) as mock_select_factory:

# Only `en_PH` factory has provider method `luzon_province`, so there is no
# need for `choices_distribution` factory selection logic to run
fake.luzon_province()
mock_select_factory.assert_called_once_with("luzon_province")
mock_choices_fn.assert_not_called()
mock_random_choice.assert_not_called()
mock_select_factory.reset_mock()
mock_choices_fn.reset_mock()
mock_random_choice.reset_mock()
mock_factory_distribution.assert_not_called()
mock_factory_choice.assert_not_called()

@patch("faker.proxy.Faker._select_factory_choice")
@patch("faker.proxy.Faker._select_factory_distribution")
def test_multiple_locale_factory_selection_shared_providers(self, mock_factory_distribution, mock_factory_choice):
locale = OrderedDict(
[
("de_DE", 3),
("en-US", 2),
("en-PH", 1),
("ja_JP", 5),
]
)
fake = Faker(locale)

with patch("faker.proxy.Faker._select_factory", wraps=fake._select_factory) as mock_select_factory:
# Both `en_US` and `ja_JP` factories have provider method `zipcode`
fake.zipcode()
mock_select_factory.assert_called_once_with("zipcode")
mock_choices_fn.assert_called_once_with(
[fake["en_US"], fake["ja_JP"]],
[2, 5],
length=1,
)
mock_random_choice.assert_not_called()
mock_select_factory.reset_mock()
mock_choices_fn.reset_mock()
mock_random_choice.reset_mock()
mock_factory_distribution.assert_called_once_with([fake["en_US"], fake["ja_JP"]], [2, 5])
mock_factory_choice.assert_not_called()

def test_multiple_locale_factory_selection_unsupported_method(self):
fake = Faker(["en_US", "en_PH"])
Expand Down

0 comments on commit 5128ae6

Please sign in to comment.