-
Notifications
You must be signed in to change notification settings - Fork 174
/
views.py
147 lines (116 loc) · 4.76 KB
/
views.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from contextlib import contextmanager

from django.contrib.auth import BACKEND_SESSION_KEY, get_user_model, load_backend, login
from django.contrib.auth.mixins import LoginRequiredMixin, UserPassesTestMixin
from django.core.exceptions import ValidationError
from django.db import ConnectionRouter, transaction
from django.http import Http404, HttpResponseBadRequest, HttpResponseRedirect
from django.shortcuts import get_object_or_404, resolve_url
from django.utils.decorators import method_decorator
from django.utils.http import url_has_allowed_host_and_scheme
from django.utils.module_loading import import_string
from django.views import View
from django.views.decorators.csrf import csrf_protect
from django.views.generic.detail import SingleObjectMixin

from hijack import signals
from hijack.conf import settings


def get_used_backend(request):
    """Return an instance of the auth backend recorded for this session.

    The dotted backend path stored under ``BACKEND_SESSION_KEY`` is loaded
    with ``load_backend``. The session lookup raises ``KeyError`` when no
    backend has been recorded.
    """
    return load_backend(request.session[BACKEND_SESSION_KEY])


@contextmanager
def keep_session_age(session):
    """Preserve the session's ``_session_expiry`` value across the block.

    The value stored under ``_session_expiry`` (the key Django's
    ``set_expiry()`` writes) is snapshotted on entry. It is written back when
    the ``with`` block exits normally, so an operation that clears or replaces
    session data (such as ``login()``) does not reset the session's age. If
    the key was absent on entry, nothing is restored. If the block raises, the
    exception propagates and nothing is restored.

    The presence check happens before the ``yield`` rather than in a
    ``try/except KeyError`` around it. Otherwise an exception raised inside
    the ``with`` block would be thrown into the generator while a ``KeyError``
    is being handled, and would be implicitly chained to that stale error.
    """
    has_expiry = "_session_expiry" in session
    session_expiry = session["_session_expiry"] if has_expiry else None
    yield
    if has_expiry:
        session["_session_expiry"] = session_expiry


class SuccessUrlMixin:
    """Choose where to redirect after a hijack action.

    A caller-supplied ``next`` value (POST first, then GET) wins if it is
    safe. Otherwise ``success_url`` is resolved, falling back to ``"/"``.
    """

    # Name of the request parameter carrying the caller's redirect target.
    redirect_field_name = "next"
    # Fallback destination; may be a URL, a URL name, or empty/None.
    success_url = "/"

    def get_success_url(self):
        """Return the safe redirect target, or the resolved ``success_url``."""
        target = self.get_redirect_url()
        if target:
            return target
        return resolve_url(self.success_url or "/")

    def get_redirect_url(self):
        """Return the user-originating redirect URL if it's safe, else ``""``.

        A URL is accepted only if ``url_has_allowed_host_and_scheme`` approves
        it against the request's own host. On HTTPS requests an absolute
        ``http://`` target is rejected.
        """
        field = self.redirect_field_name
        post_data = self.request.POST
        candidate = post_data.get(field, self.request.GET.get(field, ""))
        is_safe = url_has_allowed_host_and_scheme(
            url=candidate,
            allowed_hosts=self.request.get_host(),
            require_https=self.request.is_secure(),
        )
        if not is_safe:
            return ""
        return candidate


class LockUserTableMixin:
    """Serialize concurrent hijack/release requests made by the same user.

    The requesting user's row is locked with ``SELECT ... FOR UPDATE`` inside
    a transaction on the user model's write database. The rest of the view is
    dispatched *inside* that transaction, so the lock is held until the
    downstream ``dispatch`` returns.

    Assumes ``request.user`` is an authenticated user with a database row. A
    user without one makes ``.get()`` raise ``DoesNotExist``. The views in
    this module list ``LoginRequiredMixin`` before this mixin for that reason.
    """

    def dispatch(self, request, *args, **kwargs):
        user_model = get_user_model()
        write_db = ConnectionRouter().db_for_write(user_model)
        with transaction.atomic(using=write_db):
            # Lock the user row to avoid race conditions. ``_base_manager``
            # is used so a custom default manager cannot filter the row out.
            user_model._base_manager.select_for_update().get(pk=request.user.pk)
            # Must stay inside the ``with``: leaving the atomic block ends
            # the transaction and releases the row lock.
            return super().dispatch(request, *args, **kwargs)


class AcquireUserView(
    LoginRequiredMixin,
    LockUserTableMixin,
    UserPassesTestMixin,
    SuccessUrlMixin,
    SingleObjectMixin,
    View,
):
    """Start impersonating ("hijacking") another user.

    Expects a POST carrying ``user_pk``, the primary key of the user to
    hijack. The current user's primary key is appended to the session's
    ``hijack_history`` list, and the request is then logged in as the target
    user.
    """

    model = get_user_model()
    success_url = settings.LOGIN_REDIRECT_URL

    def test_func(self):
        """Run the configurable permission check for this hijack attempt.

        ``HIJACK_PERMISSION_CHECK`` is a dotted path, imported with
        ``import_string``, to a callable taking ``hijacker`` and ``hijacked``
        keyword arguments.
        """
        func = import_string(settings.HIJACK_PERMISSION_CHECK)
        return func(hijacker=self.request.user, hijacked=self.get_object())

    def get_object(self, queryset=None):
        """Return the user named by the POSTed ``user_pk``.

        Looks the user up in ``queryset`` when one is given, otherwise in
        ``self.model``.

        Raises:
            Http404: if no such user exists, or if ``user_pk`` is not a valid
                value for the primary key field (e.g. non-numeric for an
                integer pk). ``user_pk`` is untrusted input, so a malformed
                value is treated like an unknown one rather than causing a
                server error.
        """
        if queryset is None:
            queryset = self.model
        try:
            return get_object_or_404(queryset, pk=self.request.POST["user_pk"])
        except (ValueError, TypeError, ValidationError) as err:
            raise Http404("Invalid user_pk.") from err

    def dispatch(self, request, *args, **kwargs):
        # Reject early: every later step (permission check, lookup) needs it.
        if "user_pk" not in self.request.POST:
            return HttpResponseBadRequest()
        return super().dispatch(request, *args, **kwargs)

    @method_decorator(csrf_protect)
    def post(self, request, *args, **kwargs):
        hijacker = request.user
        hijacked = self.get_object()

        # Work on a copy so the stored session list is not mutated before
        # the login below has succeeded.
        hijack_history = list(request.session.get("hijack_history", []))
        hijack_history.append(request.user._meta.pk.value_to_string(hijacker))

        # Re-use the backend that authenticated the hijacker, as a dotted path.
        backend = get_used_backend(request)
        backend = f"{backend.__module__}.{backend.__class__.__name__}"

        # ``no_update_last_login`` presumably suppresses Django's last_login
        # update (see hijack.signals). ``keep_session_age`` puts back the
        # session's ``_session_expiry`` after login.
        with signals.no_update_last_login(), keep_session_age(request.session):
            login(request, hijacked, backend=backend)

        # Written after login, since login may replace the session contents.
        request.session["hijack_history"] = hijack_history

        signals.hijack_started.send(
            sender=None,
            request=request,
            hijacker=hijacker,
            hijacked=hijacked,
        )
        return HttpResponseRedirect(self.get_success_url())


class ReleaseUserView(
    LoginRequiredMixin,
    LockUserTableMixin,
    UserPassesTestMixin,
    SuccessUrlMixin,
    View,
):
    """Stop impersonating and log back in as the previous user.

    Only permitted while the session's ``hijack_history`` is non-empty. The
    last entry is the primary key of the user to restore.
    """

    # Access failures raise PermissionDenied instead of redirecting to login.
    raise_exception = True
    success_url = settings.LOGOUT_REDIRECT_URL

    def test_func(self):
        return bool(self.request.session.get("hijack_history", []))

    @method_decorator(csrf_protect)
    def post(self, request, *args, **kwargs):
        # Pop from a copy. Popping the session's own list would alter the
        # stored history even when the lookup below raises Http404.
        hijack_history = list(request.session.get("hijack_history", []))
        hijacked = request.user
        user_pk = hijack_history.pop()
        hijacker = get_object_or_404(get_user_model(), pk=user_pk)

        # Re-use the backend of the current session, as a dotted path.
        backend = get_used_backend(request)
        backend = f"{backend.__module__}.{backend.__class__.__name__}"

        # ``no_update_last_login`` presumably suppresses Django's last_login
        # update (see hijack.signals). ``keep_session_age`` puts back the
        # session's ``_session_expiry`` after login.
        with signals.no_update_last_login(), keep_session_age(request.session):
            login(request, hijacker, backend=backend)

        # Written after login, since login may replace the session contents.
        request.session["hijack_history"] = hijack_history

        signals.hijack_ended.send(
            sender=None,
            request=request,
            hijacker=hijacker,
            hijacked=hijacked,
        )
        return HttpResponseRedirect(self.get_success_url())