Skip to content

Commit c8bba84

Browse files
authored
Prevent accidentally setting remember token (#899)
1 parent 019dbe3 commit c8bba84

File tree

6 files changed

+30
-41
lines changed

6 files changed

+30
-41
lines changed

.editorconfig

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ insert_final_newline = true
77
trim_trailing_whitespace = true
88
end_of_line = lf
99
charset = utf-8
10-
max_line_length = 88
10+
max_line_length = 100
1111

1212
[*.{yml,yaml,json,js,css,html}]
1313
indent_size = 2

.readthedocs.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@ python:
99
- method: pip
1010
path: .
1111
sphinx:
12+
configuration: docs/conf.py
1213
builder: dirhtml
1314
fail_on_warning: true

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ src = ["src"]
6363
fix = true
6464
show-fixes = true
6565
show-source = true
66+
line-length = 100
6667

6768
[tool.ruff.lint]
6869
select = [

src/flask_login/login_manager.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,7 @@ def _load_user(self):
326326
if user is None:
327327
config = current_app.config
328328
cookie_name = config.get("REMEMBER_COOKIE_NAME", COOKIE_NAME)
329-
has_cookie = (
330-
cookie_name in request.cookies and session.get("_remember") != "clear"
331-
)
329+
has_cookie = cookie_name in request.cookies and session.get("_remember") != "clear"
332330
if has_cookie:
333331
cookie = request.cookies[cookie_name]
334332
user = self._load_user_from_remember_cookie(cookie)
@@ -389,23 +387,27 @@ def _load_user_from_request(self, request):
389387
return None
390388

391389
def _update_remember_cookie(self, response):
392-
# Don't modify the session unless there's something to do.
393-
if "_remember" not in session and current_app.config.get(
394-
"REMEMBER_COOKIE_REFRESH_EACH_REQUEST"
395-
):
396-
session["_remember"] = "set"
390+
config = current_app.config
391+
cookie_name = config.get("REMEMBER_COOKIE_NAME", COOKIE_NAME)
392+
has_cookie = cookie_name in request.cookies and session.get("_remember") != "clear"
393+
refresh = current_app.config.get("REMEMBER_COOKIE_REFRESH_EACH_REQUEST") and has_cookie
397394

398-
if "_remember" in session:
399-
operation = session.pop("_remember", None)
395+
operation = session.pop("_remember", None)
396+
if not operation and not refresh:
397+
return response
400398

401-
if operation == "set" and "_user_id" in session:
402-
self._set_cookie(response)
403-
elif operation == "clear":
404-
self._clear_cookie(response)
399+
if operation == "clear":
400+
self._clear_cookie(response)
401+
402+
if operation == "set" or refresh:
403+
self._set_cookie(response)
405404

406405
return response
407406

408407
def _set_cookie(self, response):
408+
if "_user_id" not in session:
409+
return
410+
409411
# cookie settings
410412
config = current_app.config
411413
cookie_name = config.get("REMEMBER_COOKIE_NAME", COOKIE_NAME)
@@ -431,8 +433,7 @@ def _set_cookie(self, response):
431433
expires = datetime.now(timezone.utc) + duration
432434
except TypeError as e:
433435
raise Exception(
434-
"REMEMBER_COOKIE_DURATION must be a datetime.timedelta,"
435-
f" instead got: {duration}"
436+
f"REMEMBER_COOKIE_DURATION must be a datetime.timedelta, instead got: {duration}"
436437
) from e
437438

438439
# actually set it

src/flask_login/utils.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,7 @@ def login_url(login_view, next_url=None, next_field="next"):
126126
md = parse_qs(parsed_result.query, keep_blank_values=True)
127127
md[next_field] = make_next_param(base, next_url)
128128
netloc = current_app.config.get("FORCE_HOST_FOR_REDIRECTS") or parsed_result.netloc
129-
parsed_result = parsed_result._replace(
130-
netloc=netloc, query=urlencode(md, doseq=True)
131-
)
129+
parsed_result = parsed_result._replace(netloc=netloc, query=urlencode(md, doseq=True))
132130
return urlunsplit(parsed_result)
133131

134132

@@ -191,8 +189,7 @@ def login_user(user, remember=False, duration=None, force=False, fresh=True):
191189
try:
192190
# equal to timedelta.total_seconds() but works with Python 2.6
193191
session["_remember_seconds"] = (
194-
duration.microseconds
195-
+ (duration.seconds + duration.days * 24 * 3600) * 10**6
192+
duration.microseconds + (duration.seconds + duration.days * 24 * 3600) * 10**6
196193
) / 10.0**6
197194
except AttributeError as e:
198195
raise Exception(

tests/test_login.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def test_no_user_loader_raises(self):
185185
class MethodViewLoginTestCase(unittest.TestCase):
186186
def setUp(self):
187187
self.app = Flask(__name__)
188+
self.app.config["SECRET_KEY"] = "deterministic"
188189
self.login_manager = LoginManager()
189190
self.login_manager.init_app(self.app)
190191

@@ -709,9 +710,7 @@ def test_set_cookie_with_invalid_custom_duration_raises_exception(self):
709710
with self.app.test_request_context():
710711
login_user(notch, remember=True, duration="123")
711712

712-
expected_exception_message = (
713-
"duration must be a datetime.timedelta, instead got: 123"
714-
)
713+
expected_exception_message = "duration must be a datetime.timedelta, instead got: 123"
715714
self.assertIn(expected_exception_message, str(cm.exception))
716715

717716
def test_remember_me_no_refresh_every_request(self):
@@ -732,7 +731,7 @@ def test_remember_me_refresh_each_request(self):
732731
now = datetime.now(timezone.utc)
733732
mock_dt.now = Mock(return_value=now)
734733

735-
domain = self.app.config["REMEMBER_COOKIE_DOMAIN"] = "localhost.local"
734+
domain = self.app.config["REMEMBER_COOKIE_DOMAIN"] = "localhost"
736735
path = self.app.config["REMEMBER_COOKIE_PATH"] = "/"
737736
self.app.config["REMEMBER_COOKIE_REFRESH_EACH_REQUEST"] = True
738737
c = self.app.test_client()
@@ -886,9 +885,7 @@ def test_user_login_confirmed_signal_fired(self):
886885
def test_session_not_modified(self):
887886
with self.app.test_client() as c:
888887
# Within the request we think we didn't modify the session.
889-
self.assertEqual(
890-
"modified=False", c.get("/empty_session").data.decode("utf-8")
891-
)
888+
self.assertEqual("modified=False", c.get("/empty_session").data.decode("utf-8"))
892889
# But after the request, the session could be modified by the
893890
# "after_request" handlers that call _update_remember_cookie.
894891
# Ensure that if nothing changed the session is not modified.
@@ -1260,25 +1257,19 @@ def test_make_next_param(self):
12601257
url = make_next_param("https://localhost/login", "http://localhost/profile")
12611258
self.assertEqual("http://localhost/profile", url)
12621259

1263-
url = make_next_param(
1264-
"http://accounts.localhost/login", "http://localhost/profile"
1265-
)
1260+
url = make_next_param("http://accounts.localhost/login", "http://localhost/profile")
12661261
self.assertEqual("http://localhost/profile", url)
12671262

12681263
def test_login_url_generation(self):
12691264
with self.app.test_request_context():
12701265
PROTECTED = "http://localhost/protected"
12711266

1272-
self.assertEqual(
1273-
"/login?n=%2Fprotected", login_url("/login", PROTECTED, "n")
1274-
)
1267+
self.assertEqual("/login?n=%2Fprotected", login_url("/login", PROTECTED, "n"))
12751268

12761269
url = login_url("/login", PROTECTED)
12771270
self.assertEqual("/login?next=%2Fprotected", url)
12781271

1279-
expected = (
1280-
"https://auth.localhost/login?next=http%3A%2F%2Flocalhost%2Fprotected"
1281-
)
1272+
expected = "https://auth.localhost/login?next=http%3A%2F%2Flocalhost%2Fprotected"
12821273
result = login_url("https://auth.localhost/login", PROTECTED)
12831274
self.assertEqual(expected, result)
12841275

@@ -1290,9 +1281,7 @@ def test_login_url_generation(self):
12901281

12911282
def test_login_url_generation_with_view(self):
12921283
with self.app.test_request_context():
1293-
self.assertEqual(
1294-
"/login?next=%2Fprotected", login_url("login", "/protected")
1295-
)
1284+
self.assertEqual("/login?next=%2Fprotected", login_url("login", "/protected"))
12961285

12971286
def test_login_url_no_next_url(self):
12981287
self.assertEqual(login_url("/foo"), "/foo")

0 commit comments

Comments
 (0)