Skip to content

Commit

Permalink
feat: added sid in oidc session and methods to get revocation easier
Browse files Browse the repository at this point in the history
  • Loading branch information
peppelinux committed Mar 7, 2022
1 parent f233568 commit 17530d0
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 15 deletions.
6 changes: 5 additions & 1 deletion spid_cie_oidc/accounts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,8 @@ def clear_sessions(self): # pragma: no cover
return Session.objects.filter(pk__in=user_sessions).delete()

def __str__(self): # pragma: no cover
return "{} {}".format(self.first_name, self.last_name)
return "{} - {} {}".format(
self.username,
self.first_name or self.attributes.get('given_name'),
self.last_name or self.attributes.get('family_name', "")
)
1 change: 1 addition & 0 deletions spid_cie_oidc/provider/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class OidcSessionAdmin(admin.ModelAdmin):
"auth_code",
"user_uid",
"user",
"sid",
"client_id",
"created",
"revoked",
Expand Down
27 changes: 26 additions & 1 deletion spid_cie_oidc/provider/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from django.contrib.auth import get_user_model
from django.contrib.sessions.models import Session
from django.db import models
from django.utils import timezone
from django.utils.translation import gettext as _
from spid_cie_oidc.entity.abstract_models import TimeStampedModel

import hashlib
import logging

from spid_cie_oidc.provider.settings import OIDCFED_PROVIDER_SALT

logger = logging.getLogger(__name__)


class OidcSession(TimeStampedModel):
"""
Expand All @@ -18,13 +23,33 @@ class OidcSession(TimeStampedModel):
get_user_model(), on_delete=models.SET_NULL, blank=True, null=True
)
client_id = models.URLField(blank=True, null=True)

sid = models.CharField(
max_length=1024, blank=True, null=True,
help_text=_("django session key")
)
nonce = models.CharField(max_length=2048, blank=False, null=False)
authz_request = models.JSONField(max_length=2048, blank=False, null=False)

revoked = models.BooleanField(default=False)
auth_code = models.CharField(max_length=2048, blank=False, null=False)

def set_sid(self, request):
try:
Session.objects.get(session_key=request.session.session_key)
self.sid = request.session.session_key
self.save()
except Exception:
logger.warning(f"Error setting SID for OidcSession {self}")

def revoke(self):
session = Session.objects.filter(session_key=self.sid)
if session:
session.delete()
self.revoked = True
iss_tokens = IssuedToken.objects.filter(session=self)
iss_tokens.update(revoked=True)
self.save()

def pairwised_sub(self):
return hashlib.sha256(
f"{self.user_uid}{self.client_id}{OIDCFED_PROVIDER_SALT}".encode()
Expand Down
30 changes: 17 additions & 13 deletions spid_cie_oidc/provider/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def validate_authz_request_object(self, req) -> TrustChain:

def is_a_replay_authz(self):
preexistent_authz = OidcSession.objects.filter(
client_id=self.payload["client_id"],
client_id=self.payload["client_id"],
nonce=self.payload["nonce"]
).first()
if preexistent_authz:
Expand All @@ -155,7 +155,7 @@ def check_session(self, request) -> OidcSession:
raise InvalidSession()

session = OidcSession.objects.filter(
auth_code=request.session["oidc"]["auth_code"],
auth_code=request.session["oidc"]["auth_code"],
user=request.user
).first()

Expand All @@ -178,13 +178,14 @@ def check_client_assertion(self, client_id: str, client_assertion: str) -> bool:
if payload['sub'] != client_id:
# TODO Specialize exceptions
raise Exception()

tc = TrustChain.objects.get(sub=client_id, is_active=True)
jwk = self.find_jwk(head, tc.metadata['jwks']['keys'])
verify_jws(client_assertion, jwk)

return True


class AuthzRequestView(OpBase, View):
"""
View which processes the actual Authz request and
Expand Down Expand Up @@ -340,14 +341,16 @@ def post(self, request, *args, **kwargs):
request.session["oidc"] = {"auth_code": auth_code}

# store the User session
OidcSession.objects.create(
session = OidcSession.objects.create(
user=user,
user_uid=user.username,
nonce=self.payload["nonce"],
authz_request=self.payload,
client_id=self.payload["client_id"],
auth_code=auth_code,
)
session.set_sid(request)

consent_url = reverse("oidc_provider_consent")
return HttpResponseRedirect(consent_url)

Expand Down Expand Up @@ -428,8 +431,8 @@ def post(self, request, *args, **kwargs):
class TokenEndpoint(OpBase, View):
def get_jwt_common_data(self):
return {
"jti": str(uuid.uuid4()),
"exp": exp_from_now(),
"jti": str(uuid.uuid4()),
"exp": exp_from_now(),
"iat": iat_now()
}

Expand Down Expand Up @@ -501,7 +504,7 @@ def grant_auth_code(self, request, *args, **kwargs):

# refresh token is scope offline_access and prompt == consent
if (
"offline_access" in self.authz.authz_request['scope'] and
"offline_access" in self.authz.authz_request['scope'] and
'consent' in self.authz.authz_request['prompt']
):
refresh_token = {
Expand All @@ -515,7 +518,7 @@ def grant_auth_code(self, request, *args, **kwargs):
iss_token_data['refresh_token'] = refresh_token

IssuedToken.objects.create(**iss_token_data)

expires_in = timezone.timedelta(
seconds = access_token['exp'] - access_token['iat']
).seconds
Expand Down Expand Up @@ -556,7 +559,7 @@ def post(self, request, *args, **kwargs):
self.issuer = self.get_issuer()

self.authz = OidcSession.objects.filter(
auth_code=request.POST["code"],
auth_code=request.POST["code"],
revoked=False
).first()

Expand All @@ -566,7 +569,7 @@ def post(self, request, *args, **kwargs):
# check client_assertion and client ownership
try:
self.check_client_assertion(
request.POST['client_id'],
request.POST['client_id'],
request.POST['client_assertion']
)
except Exception:
Expand All @@ -576,7 +579,7 @@ def post(self, request, *args, **kwargs):
{
'error': "...",
'error_description': "..."

}, status = 403
)

Expand All @@ -587,6 +590,7 @@ def post(self, request, *args, **kwargs):
else:
raise NotImplementedError()


class UserInfoEndpoint(OpBase, View):
def get(self, request, *args, **kwargs):

Expand All @@ -606,8 +610,8 @@ def get(self, request, *args, **kwargs):
return HttpResponseForbidden()

rp_tc = TrustChain.objects.filter(
sub=token.session.client_id,
type="openid_relying_party",
sub=token.session.client_id,
type="openid_relying_party",
is_active=True
).first()
if not rp_tc:
Expand Down

0 comments on commit 17530d0

Please sign in to comment.