Skip to content

Commit

Permalink
Store the IdP access token and expiry into the session (#154)
Browse files Browse the repository at this point in the history
  • Loading branch information
yorinasub17 committed Jul 10, 2023
1 parent 42a0a76 commit cfd1d5b
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 15 deletions.
19 changes: 15 additions & 4 deletions webstd/chistd/oidc_handlers.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package chistd

import (
"encoding/gob"
"net/http"
"time"

"github.com/alexedwards/scs/v2"
"github.com/go-chi/chi/v5"
Expand All @@ -12,17 +14,24 @@ import (
"github.com/fensak-io/gostd/webstd"
)

func init() {
// Register time.Time to gob so it can be stored in the session
gob.Register(time.Time{})
}

const (
// URL paths
OIDCLoginPath = "/oidc/login"
OIDCLogoutPath = "/oidc/logout"
OIDCCallbackPath = "/oidc/callback"

// Session keys
RefreshTokenSessionKey = "refresh_token"
UserProfileSessionKey = "profile"
ContinueToURLSessionKey = "continue_to"
PKCECodeVerifierSessionKey = "pkce_code_verifier"
AccessTokenSessionKey = "access_token"
AccessTokenExpirySessionKey = "access_token_expiry"
RefreshTokenSessionKey = "refresh_token"
UserProfileSessionKey = "profile"
ContinueToURLSessionKey = "continue_to"
PKCECodeVerifierSessionKey = "pkce_code_verifier"
)

type OIDCHandlerContext[T any] struct {
Expand Down Expand Up @@ -159,6 +168,8 @@ func (h OIDCHandlerContext[T]) oidcCallbackHandler(w http.ResponseWriter, r *htt
return
}

h.sessMgr.Put(ctx, AccessTokenSessionKey, token.AccessToken)
h.sessMgr.Put(ctx, AccessTokenExpirySessionKey, token.Expiry)
h.sessMgr.Put(ctx, RefreshTokenSessionKey, token.RefreshToken)
h.sessMgr.Put(ctx, UserProfileSessionKey, profile)

Expand Down
8 changes: 5 additions & 3 deletions webstd/idp/aadb2c/aadb2c.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,14 @@ func (a AADB2C) GetLogoutURL(ctx context.Context) (string, error) {
}

refreshToken := a.sessMgr.GetString(ctx, chistd.RefreshTokenSessionKey)
rawIDToken, _, newRefreshToken, err := a.auth.RefreshIDToken(ctx, refreshToken)
rawIDToken, _, newToken, err := a.auth.RefreshIDToken(ctx, refreshToken)
if err != nil {
return "", err
}
if newRefreshToken != "" {
a.sessMgr.Put(ctx, chistd.RefreshTokenSessionKey, newRefreshToken)
a.sessMgr.Put(ctx, chistd.AccessTokenSessionKey, newToken.AccessToken)
a.sessMgr.Put(ctx, chistd.AccessTokenExpirySessionKey, newToken.Expiry)
if newToken.RefreshToken != "" {
a.sessMgr.Put(ctx, chistd.RefreshTokenSessionKey, newToken.RefreshToken)
}

appURLCopy := a.appURL
Expand Down
8 changes: 5 additions & 3 deletions webstd/idp/zitadel/zitadel.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,14 @@ func (z Zitadel) GetLogoutURL(ctx context.Context) (string, error) {
}

refreshToken := z.sessMgr.GetString(ctx, chistd.RefreshTokenSessionKey)
rawIDToken, _, newRefreshToken, err := z.auth.RefreshIDToken(ctx, refreshToken)
rawIDToken, _, newToken, err := z.auth.RefreshIDToken(ctx, refreshToken)
if err != nil {
return "", err
}
if newRefreshToken != "" {
z.sessMgr.Put(ctx, chistd.RefreshTokenSessionKey, newRefreshToken)
z.sessMgr.Put(ctx, chistd.AccessTokenSessionKey, newToken.AccessToken)
z.sessMgr.Put(ctx, chistd.AccessTokenExpirySessionKey, newToken.Expiry)
if newToken.RefreshToken != "" {
z.sessMgr.Put(ctx, chistd.RefreshTokenSessionKey, newToken.RefreshToken)
}

appURLCopy := z.appURL
Expand Down
10 changes: 5 additions & 5 deletions webstd/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,22 +91,22 @@ func (a Authenticator) VerifyIDToken(ctx context.Context, token *oauth2.Token) (
}

// RefreshIDToken obtains a new OIDC ID token using the provided refresh token.
func (a Authenticator) RefreshIDToken(ctx context.Context, refreshToken string) (string, *oidc.IDToken, string, error) {
func (a Authenticator) RefreshIDToken(ctx context.Context, refreshToken string) (string, *oidc.IDToken, *oauth2.Token, error) {
ts := a.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken})
token, err := ts.Token()
if err != nil {
return "", nil, "", err
return "", nil, nil, err
}

rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
return "", nil, "", errors.New("no id_token field in oauth2 token")
return "", nil, nil, errors.New("no id_token field in oauth2 token")
}
idToken, err := a.VerifyIDToken(ctx, token)
if err != nil {
return "", nil, "", err
return "", nil, nil, err
}
return rawIDToken, idToken, token.RefreshToken, nil
return rawIDToken, idToken, token, nil
}

// VerifyRawToken verifies a given raw JWT token string issued by the OIDC provider. This is useful for verifying tokens
Expand Down

0 comments on commit cfd1d5b

Please sign in to comment.