diff --git a/django_pandas/utils.py b/django_pandas/utils.py index 8bd4b50..0540cd8 100644 --- a/django_pandas/utils.py +++ b/django_pandas/utils.py @@ -1,6 +1,5 @@ # coding: utf-8 -from math import isnan from django.core.cache import cache from django.utils.encoding import force_text @@ -31,19 +30,27 @@ def invalidate_signal_handler(sender, **kwargs): def replace_pk(model): base_cache_key = get_base_cache_key(model) - def inner(pk_list): - cache_keys = [None if isnan(pk) else base_cache_key % pk - for pk in pk_list] - out_dict = cache.get_many(frozenset(cache_keys)) - try: - return [None if k is None else out_dict[k] for k in cache_keys] - except KeyError: - out_dict = { - base_cache_key % obj.pk: force_text(obj) - for obj in model.objects.filter(pk__in={pk for pk in pk_list - if not isnan(pk)})} + def get_cache_key_from_pk(pk): + return None if pk is None else base_cache_key % pk + + def inner(pk_series): + pk_series = pk_series.where(pk_series.notnull(), None) + cache_keys = pk_series.apply( + get_cache_key_from_pk, convert_dtype=False) + unique_cache_keys = list(filter(None, cache_keys.unique())) + + if not unique_cache_keys: + return pk_series + + out_dict = cache.get_many(unique_cache_keys) + + if len(out_dict) < len(unique_cache_keys): + out_dict = {base_cache_key % obj.pk: force_text(obj) + for obj in model.objects.filter( + pk__in=list(filter(None, pk_series.unique())))} cache.set_many(out_dict) - return list(map(out_dict.get, cache_keys)) + + return list(map(out_dict.get, cache_keys)) return inner