Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed #34901 -- Added async-compatible interface to session engines. #17372

Merged
merged 1 commit into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion django/contrib/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,4 +269,6 @@ def update_session_auth_hash(request, user):

async def aupdate_session_auth_hash(request, user):
"""See update_session_auth_hash()."""
return await sync_to_async(update_session_auth_hash)(request, user)
await request.session.acycle_key()
if hasattr(user, "get_session_auth_hash") and request.user == user:
await request.session.aset(HASH_SESSION_KEY, user.get_session_auth_hash())
159 changes: 159 additions & 0 deletions django/contrib/sessions/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import string
from datetime import datetime, timedelta

from asgiref.sync import sync_to_async

from django.conf import settings
from django.core import signing
from django.utils import timezone
Expand Down Expand Up @@ -56,6 +58,10 @@ def __setitem__(self, key, value):
self._session[key] = value
self.modified = True

async def aset(self, key, value):
(await self._aget_session())[key] = value
self.modified = True

def __delitem__(self, key):
del self._session[key]
self.modified = True
Expand All @@ -67,27 +73,52 @@ def key_salt(self):
def get(self, key, default=None):
return self._session.get(key, default)

async def aget(self, key, default=None):
return (await self._aget_session()).get(key, default)

def pop(self, key, default=__not_given):
self.modified = self.modified or key in self._session
args = () if default is self.__not_given else (default,)
return self._session.pop(key, *args)

async def apop(self, key, default=__not_given):
self.modified = self.modified or key in (await self._aget_session())
args = () if default is self.__not_given else (default,)
return (await self._aget_session()).pop(key, *args)

def setdefault(self, key, value):
if key in self._session:
return self._session[key]
else:
self[key] = value
return value

async def asetdefault(self, key, value):
session = await self._aget_session()
if key in session:
return session[key]
else:
await self.aset(key, value)
return value

def set_test_cookie(self):
self[self.TEST_COOKIE_NAME] = self.TEST_COOKIE_VALUE

async def aset_test_cookie(self):
await self.aset(self.TEST_COOKIE_NAME, self.TEST_COOKIE_VALUE)
bigfootjon marked this conversation as resolved.
Show resolved Hide resolved

def test_cookie_worked(self):
return self.get(self.TEST_COOKIE_NAME) == self.TEST_COOKIE_VALUE

async def atest_cookie_worked(self):
return (await self.aget(self.TEST_COOKIE_NAME)) == self.TEST_COOKIE_VALUE

def delete_test_cookie(self):
del self[self.TEST_COOKIE_NAME]

async def adelete_test_cookie(self):
del (await self._aget_session())[self.TEST_COOKIE_NAME]

def encode(self, session_dict):
"Return the given session dictionary serialized and encoded as a string."
return signing.dumps(
Expand Down Expand Up @@ -115,18 +146,34 @@ def update(self, dict_):
self._session.update(dict_)
self.modified = True

async def aupdate(self, dict_):
(await self._aget_session()).update(dict_)
self.modified = True

def has_key(self, key):
return key in self._session

async def ahas_key(self, key):
return key in (await self._aget_session())

def keys(self):
return self._session.keys()

async def akeys(self):
return (await self._aget_session()).keys()

def values(self):
return self._session.values()

async def avalues(self):
return (await self._aget_session()).values()

def items(self):
return self._session.items()

async def aitems(self):
return (await self._aget_session()).items()

def clear(self):
# To avoid unnecessary persistent storage accesses, we set up the
# internals directly (loading data wastes time, since we are going to
Expand All @@ -149,11 +196,22 @@ def _get_new_session_key(self):
if not self.exists(session_key):
return session_key

async def _aget_new_session_key(self):
while True:
session_key = get_random_string(32, VALID_KEY_CHARS)
if not await self.aexists(session_key):
return session_key

def _get_or_create_session_key(self):
if self._session_key is None:
self._session_key = self._get_new_session_key()
return self._session_key

async def _aget_or_create_session_key(self):
if self._session_key is None:
self._session_key = await self._aget_new_session_key()
return self._session_key

def _validate_session_key(self, key):
"""
Key must be truthy and at least 8 characters long. 8 characters is an
Expand Down Expand Up @@ -191,6 +249,17 @@ def _get_session(self, no_load=False):
self._session_cache = self.load()
return self._session_cache

async def _aget_session(self, no_load=False):
self.accessed = True
try:
return self._session_cache
except AttributeError:
if self.session_key is None or no_load:
self._session_cache = {}
else:
self._session_cache = await self.aload()
return self._session_cache

_session = property(_get_session)

def get_session_cookie_age(self):
Expand Down Expand Up @@ -223,6 +292,25 @@ def get_expiry_age(self, **kwargs):
delta = expiry - modification
return delta.days * 86400 + delta.seconds

async def aget_expiry_age(self, **kwargs):
try:
modification = kwargs["modification"]
except KeyError:
modification = timezone.now()
try:
expiry = kwargs["expiry"]
except KeyError:
expiry = await self.aget("_session_expiry")

if not expiry: # Checks both None and 0 cases
return self.get_session_cookie_age()
if not isinstance(expiry, (datetime, str)):
return expiry
if isinstance(expiry, str):
expiry = datetime.fromisoformat(expiry)
delta = expiry - modification
return delta.days * 86400 + delta.seconds

def get_expiry_date(self, **kwargs):
"""Get session the expiry date (as a datetime object).

Expand All @@ -246,6 +334,23 @@ def get_expiry_date(self, **kwargs):
expiry = expiry or self.get_session_cookie_age()
return modification + timedelta(seconds=expiry)

async def aget_expiry_date(self, **kwargs):
try:
modification = kwargs["modification"]
except KeyError:
modification = timezone.now()
try:
expiry = kwargs["expiry"]
except KeyError:
expiry = await self.aget("_session_expiry")

if isinstance(expiry, datetime):
return expiry
elif isinstance(expiry, str):
return datetime.fromisoformat(expiry)
expiry = expiry or self.get_session_cookie_age()
return modification + timedelta(seconds=expiry)

def set_expiry(self, value):
"""
Set a custom expiration for the session. ``value`` can be an integer,
Expand Down Expand Up @@ -274,6 +379,20 @@ def set_expiry(self, value):
value = value.isoformat()
self["_session_expiry"] = value

async def aset_expiry(self, value):
if value is None:
# Remove any custom expiration for this session.
try:
await self.apop("_session_expiry")
felixxm marked this conversation as resolved.
Show resolved Hide resolved
except KeyError:
pass
return
if isinstance(value, timedelta):
value = timezone.now() + value
if isinstance(value, datetime):
value = value.isoformat()
await self.aset("_session_expiry", value)

def get_expire_at_browser_close(self):
"""
Return ``True`` if the session is set to expire when the browser
Expand All @@ -285,6 +404,11 @@ def get_expire_at_browser_close(self):
return settings.SESSION_EXPIRE_AT_BROWSER_CLOSE
return expiry == 0

async def aget_expire_at_browser_close(self):
if (expiry := await self.aget("_session_expiry")) is None:
return settings.SESSION_EXPIRE_AT_BROWSER_CLOSE
return expiry == 0

def flush(self):
"""
Remove the current session data from the database and regenerate the
Expand All @@ -294,6 +418,11 @@ def flush(self):
self.delete()
self._session_key = None

async def aflush(self):
self.clear()
await self.adelete()
self._session_key = None

def cycle_key(self):
"""
Create a new session key, while retaining the current session data.
Expand All @@ -305,6 +434,17 @@ def cycle_key(self):
if key:
self.delete(key)

async def acycle_key(self):
"""
Create a new session key, while retaining the current session data.
"""
data = await self._aget_session()
key = self.session_key
await self.acreate()
self._session_cache = data
if key:
await self.adelete(key)

# Methods that child classes must implement.

def exists(self, session_key):
Expand All @@ -315,6 +455,9 @@ def exists(self, session_key):
"subclasses of SessionBase must provide an exists() method"
)

async def aexists(self, session_key):
return await sync_to_async(self.exists)(session_key)

def create(self):
"""
Create a new session instance. Guaranteed to create a new object with
Expand All @@ -325,6 +468,9 @@ def create(self):
"subclasses of SessionBase must provide a create() method"
)

async def acreate(self):
return await sync_to_async(self.create)()

def save(self, must_create=False):
"""
Save the session data. If 'must_create' is True, create a new session
Expand All @@ -335,6 +481,9 @@ def save(self, must_create=False):
"subclasses of SessionBase must provide a save() method"
)

async def asave(self, must_create=False):
return await sync_to_async(self.save)(must_create)

def delete(self, session_key=None):
"""
Delete the session data under this key. If the key is None, use the
Expand All @@ -344,6 +493,9 @@ def delete(self, session_key=None):
"subclasses of SessionBase must provide a delete() method"
)

async def adelete(self, session_key=None):
return await sync_to_async(self.delete)(session_key)

def load(self):
"""
Load the session data and return a dictionary.
Expand All @@ -352,6 +504,9 @@ def load(self):
"subclasses of SessionBase must provide a load() method"
)

async def aload(self):
return await sync_to_async(self.load)()

@classmethod
def clear_expired(cls):
"""
Expand All @@ -362,3 +517,7 @@ def clear_expired(cls):
a built-in expiration mechanism, it should be a no-op.
"""
raise NotImplementedError("This backend does not support clear_expired().")

@classmethod
async def aclear_expired(cls):
return await sync_to_async(cls.clear_expired)()
Loading