Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Server] handling of token and cookies, auth and session middleware #11003

Merged
merged 1 commit into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 11 additions & 1 deletion server/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ func main() {
lProv.SeedContent(log)
provs[lProv.Name()] = lProv

providerEnvVar := viper.GetString(constants.ProviderENV)
RemoteProviderURLs := viper.GetStringSlice("PROVIDER_BASE_URLS")
for _, providerurl := range RemoteProviderURLs {
parsedURL, err := url.Parse(providerurl)
Expand All @@ -286,6 +287,7 @@ func main() {
GenericPersister: dbHandler,
EventsPersister: &models.EventsPersister{DB: dbHandler},
Log: log,
CookieDuration: 24 * time.Hour,
}

cp.Initialize()
Expand All @@ -295,6 +297,14 @@ func main() {
provs[cp.Name()] = cp
}

// verifies if the provider specified in the "PROVIDER" environment variable is from one of the supported providers.
// If it is one of the supported providers, the server gets configured to auto select the specified provider,
// else the provider specified in the environment variable is ignored and each time user logs in they need to select a provider.
isProviderEnvVarValid := models.VerifyMesheryProvider(providerEnvVar, provs)
if !isProviderEnvVarValid {
providerEnvVar = ""
}

operatorDeploymentConfig := models.NewOperatorDeploymentConfig(adapterTracker)
mctrlHelper := models.NewMesheryControllersHelper(log, operatorDeploymentConfig, dbHandler)
connToInstanceTracker := machines.ConnectionToStateMachineInstanceTracker{
Expand All @@ -305,7 +315,7 @@ func main() {

models.InitMeshSyncRegistrationQueue()
mhelpers.InitRegistrationHelperSingleton(dbHandler, log, &connToInstanceTracker, hc.EventBroadcaster)
h := handlers.NewHandlerInstance(hc, meshsyncCh, log, brokerConn, k8sComponentsRegistrationHelper, mctrlHelper, dbHandler, events.NewEventStreamer(), regManager, viper.GetString(constants.ProviderENV), &rego, &connToInstanceTracker)
h := handlers.NewHandlerInstance(hc, meshsyncCh, log, brokerConn, k8sComponentsRegistrationHelper, mctrlHelper, dbHandler, events.NewEventStreamer(), regManager, providerEnvVar, &rego, &connToInstanceTracker)

b := broadcast.NewBroadcaster(100)
defer b.Close()
Expand Down
36 changes: 20 additions & 16 deletions server/handlers/middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,14 @@ func (h *Handler) AuthMiddleware(next http.Handler, auth models.AuthenticationMe
http.Redirect(w, req, "/provider", http.StatusFound)
return
}
if providerH != "" && providerH != provider.Name() {
w.WriteHeader(http.StatusUnauthorized)
return
}

// Because server verifies the value of the "PROVIDER" environemnt variable and doesn't allow unsupported provider value,
// the below situation cannot occur.

// if providerH != "" && providerH != provider.Name() {
// w.WriteHeader(http.StatusUnauthorized)
// return
// }
// logrus.Debugf("provider %s", provider)
isValid := h.validateAuth(provider, req)
// logrus.Debugf("validate auth: %t", isValid)
Expand Down Expand Up @@ -145,18 +149,18 @@ func (h *Handler) SessionInjectorMiddleware(next func(http.ResponseWriter, *http
return
}
// ensuring session is intact
err := provider.GetSession(req)
if err != nil {
err1 := provider.Logout(w, req)
if err1 != nil {
logrus.Errorf("Error performing logout: %v", err1.Error())
provider.HandleUnAuthenticated(w, req)
return
}
logrus.Errorf("Error: unable to get session: %v", err)
http.Error(w, "unable to get session", http.StatusUnauthorized)
return
}
// err := provider.GetSession(req)
// if err != nil {
// err1 := provider.Logout(w, req)
// if err1 != nil {
// logrus.Errorf("Error performing logout: %v", err1.Error())
// provider.HandleUnAuthenticated(w, req)
// return
// }
// logrus.Errorf("Error: unable to get session: %v", err)
// http.Error(w, "unable to get session", http.StatusUnauthorized)
// return
// }

user, err := provider.GetUserDetails(req)
// if user details are not available,
Expand Down
4 changes: 2 additions & 2 deletions server/handlers/provider_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ func (h *Handler) ProvidersHandler(w http.ResponseWriter, _ *http.Request) {

// ProviderUIHandler - serves providers UI
func (h *Handler) ProviderUIHandler(w http.ResponseWriter, r *http.Request) {
if h.config.PlaygroundBuild || h.Provider == "Meshery" { //Always use Remote provider for Playground build or when Provider is enforced
if h.config.PlaygroundBuild || h.Provider != "" { //Always use Remote provider for Playground build or when Provider is enforced
http.SetCookie(w, &http.Cookie{
Name: h.config.ProviderCookieName,
Value: "Meshery",
Value: h.Provider,
Path: "/",
HttpOnly: true,
})
Expand Down
2 changes: 1 addition & 1 deletion server/models/default_local_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ func (l *DefaultLocalProvider) TokenHandler(_ http.ResponseWriter, _ *http.Reque
func (l *DefaultLocalProvider) ExtractToken(w http.ResponseWriter, _ *http.Request) {
resp := map[string]interface{}{
"meshery-provider": l.Name(),
tokenName: "",
TokenCookieName: "",
}
logrus.Debugf("token sent for meshery-provider %v", l.Name())
if err := json.NewEncoder(w).Encode(resp); err != nil {
Expand Down
9 changes: 9 additions & 0 deletions server/models/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,15 @@ func (caps Capabilities) GetEndpointForFeature(feature Feature) (string, bool) {
return "", false
}

func VerifyMesheryProvider(provider string, supportedProviders map[string]Provider) bool {
for prov := range supportedProviders {
if prov == provider {
return true
}
}
return false
}

// Provider - interface for providers
type Provider interface {
PreferencePersister
Expand Down
60 changes: 56 additions & 4 deletions server/models/remote_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,17 @@ import (
"golang.org/x/oauth2"
)

const (
// Stores meshery provider related info.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add these to docs.

ProviderCookieName = "meshery-provider"

// Stores the JWT issued by the remote provider to provide secure access to its API
TokenCookieName = "token"

// Stores the remote provider session cookie (identity cookie) to facilitate logout from remote provider as user logs out of Meshery
ProviderSessionCookieName = "session_cookie"
)

// JWK - a type respresting the JSON web Key
type JWK map[string]string

Expand Down Expand Up @@ -66,7 +77,7 @@ func (l *RemoteProvider) refreshToken(tokenString string) (string, error) {
return newTokenString, nil
}
bd := map[string]string{
tokenName: tokenString,
TokenCookieName: tokenString,
}
jsonString, err := json.Marshal(bd)
if err != nil {
Expand All @@ -86,12 +97,12 @@ func (l *RemoteProvider) refreshToken(tokenString string) (string, error) {
if err != nil {
return "", err
}
l.TokenStore[tokenString] = target[tokenName]
l.TokenStore[tokenString] = target[TokenCookieName]
time.AfterFunc(1*time.Hour, func() {
logrus.Infof("deleting old token string")
delete(l.TokenStore, tokenString)
})
return target[tokenName], nil
return target[TokenCookieName], nil
}

func (l *RemoteProvider) doRequestHelper(req *http.Request, tokenString string) (*http.Response, error) {
Expand All @@ -111,7 +122,7 @@ func (l *RemoteProvider) doRequestHelper(req *http.Request, tokenString string)

// GetToken - extracts token form a request
func (l *RemoteProvider) GetToken(req *http.Request) (string, error) {
ck, err := req.Cookie(tokenName)
ck, err := req.Cookie(TokenCookieName)
if err != nil {
return "", ErrGetToken(err)
}
Expand Down Expand Up @@ -345,3 +356,44 @@ func (l *RemoteProvider) introspectToken(tokenString string) error {

return nil
}

func setCookie(w http.ResponseWriter, name, value string, duration time.Duration) {
http.SetCookie(w, &http.Cookie{
Name: name,
Value: value,
Path: "/",
HttpOnly: true,
Expires: time.Now().Add(duration),
})
}

func unsetCookie(w http.ResponseWriter, name string) {
http.SetCookie(w, &http.Cookie{
Name: name,
MaxAge: -1,
})
}

func (l *RemoteProvider) SetProviderCookie(w http.ResponseWriter, provider string) {
setCookie(w, ProviderCookieName, provider, l.CookieDuration)
}

func (l *RemoteProvider) UnSetProviderCookie(w http.ResponseWriter) {
unsetCookie(w, ProviderCookieName)
}

func (l *RemoteProvider) SetJWTCookie(w http.ResponseWriter, token string) {
setCookie(w, TokenCookieName, token, l.CookieDuration)
}

func (l *RemoteProvider) UnSetJWTCookie(w http.ResponseWriter) {
unsetCookie(w, TokenCookieName)
}

func (l *RemoteProvider) SetProviderSessionCookie(w http.ResponseWriter, sessionCookie string) {
setCookie(w, ProviderSessionCookieName, sessionCookie, l.CookieDuration)
}

func (l *RemoteProvider) UnSetProviderSessionCookie(w http.ResponseWriter) {
unsetCookie(w, ProviderSessionCookieName)
}
64 changes: 18 additions & 46 deletions server/models/remote_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ type RemoteProvider struct {

LoginCookieDuration time.Duration

// provider and token cookie expiry bound
CookieDuration time.Duration

syncStopChan chan struct{}
syncChan chan *userSession

Expand Down Expand Up @@ -178,8 +181,6 @@ func (l *RemoteProvider) Description() []string {
return l.ProviderDescription
}

const tokenName = "token"

// GetProviderType - Returns ProviderType
func (l *RemoteProvider) GetProviderType() ProviderType {
return l.ProviderType
Expand Down Expand Up @@ -269,7 +270,7 @@ func (l *RemoteProvider) InitiateLogin(w http.ResponseWriter, r *http.Request, _
callbackURL := r.Context().Value(MesheryServerCallbackURL).(string)
mesheryVersion := viper.GetString("BUILD")

_, err := r.Cookie(tokenName)
_, err := r.Cookie(TokenCookieName)
if err != nil {
http.SetCookie(w, &http.Cookie{
Name: l.RefCookieName,
Expand Down Expand Up @@ -573,10 +574,8 @@ func (l *RemoteProvider) Logout(w http.ResponseWriter, req *http.Request) error
// make request to remote provider with contructed URL and updated headers (like session_cookie, return_to cookies)
resp, err := l.DoRequest(cReq, tokenString)
if err != nil {
if resp == nil {
return ErrUnreachableRemoteProvider(err)
}
logrus.Errorf("error performing logout: %v", err)
err = ErrUnreachableRemoteProvider(err)
l.Log.Error(err)
return err
}

Expand All @@ -593,21 +592,16 @@ func (l *RemoteProvider) Logout(w http.ResponseWriter, req *http.Request) error
// And empties the token and session cookies
if resp.StatusCode == http.StatusFound || resp.StatusCode == http.StatusOK {
// gets the token from the request headers
ck, err := req.Cookie(tokenName)
ck, err := req.Cookie(TokenCookieName)
if err == nil {
err = l.revokeToken(ck.Value)
}
if err != nil {
logrus.Errorf("error performing logout, token cannot be revoked: %v", err)
http.Redirect(w, req, "/user/login", http.StatusFound)
return nil
}
ck.MaxAge = -1
ck.Path = "/"
http.SetCookie(w, ck)
sessionCookie.MaxAge = -1
sessionCookie.Path = "/"
http.SetCookie(w, sessionCookie)
}
l.UnSetJWTCookie(w)

l.UnSetProviderSessionCookie(w)
return nil
}

Expand All @@ -621,13 +615,8 @@ func (l *RemoteProvider) Logout(w http.ResponseWriter, req *http.Request) error
func (l *RemoteProvider) HandleUnAuthenticated(w http.ResponseWriter, req *http.Request) {
_, err := req.Cookie("meshery-provider")
if err == nil {
ck, err := req.Cookie(tokenName)
if err == nil {
ck.MaxAge = -1
ck.Path = "/"
http.SetCookie(w, ck)
}

// remove the cookie from the browser and redirect to inform about expired session.
l.UnSetJWTCookie(w)
http.Redirect(w, req, "/auth/login", http.StatusFound)
return
}
Expand Down Expand Up @@ -3313,25 +3302,13 @@ func (l *RemoteProvider) RecordPreferences(req *http.Request, userID string, dat

// TokenHandler - specific to remote auth
func (l *RemoteProvider) TokenHandler(w http.ResponseWriter, r *http.Request, _ bool) {
tokenString := r.URL.Query().Get(tokenName)
tokenString := r.URL.Query().Get(TokenCookieName)
// gets the session cookie from remote provider
sessionCookie := r.URL.Query().Get("session_cookie")

ck := &http.Cookie{
Name: tokenName,
Value: string(tokenString),
Path: "/",
Expires: time.Now().Add(24 * time.Hour),
HttpOnly: true,
}
http.SetCookie(w, ck)
l.SetJWTCookie(w, tokenString)
// sets the session cookie for Meshery Session
http.SetCookie(w, &http.Cookie{
Name: "session_cookie",
Value: sessionCookie,
Path: "/",
HttpOnly: true,
})
l.SetProviderSessionCookie(w, sessionCookie)

// Get new capabilities
// Doing this here is important so that
Expand Down Expand Up @@ -3391,12 +3368,7 @@ func (l *RemoteProvider) UpdateToken(w http.ResponseWriter, r *http.Request) str
newts := l.TokenStore[tokenString]
if newts != "" {
logrus.Debugf("set updated token: %v", newts)
http.SetCookie(w, &http.Cookie{
Name: tokenName,
Value: newts,
Path: "/",
HttpOnly: true,
})
l.SetJWTCookie(w, newts)
return newts
}

Expand All @@ -3420,7 +3392,7 @@ func (l *RemoteProvider) ExtractToken(w http.ResponseWriter, r *http.Request) {

resp := map[string]interface{}{
"meshery-provider": l.Name(),
tokenName: tokenString,
TokenCookieName: tokenString,
}
logrus.Debugf("token sent for meshery-provider %v", l.Name())
if err := json.NewEncoder(w).Encode(resp); err != nil {
Expand Down