Permalink
Browse files

Merge pull request #181 from mraerino/feature/saml-provider

Implement SAML 2 external provider
  • Loading branch information...
rybit committed Sep 24, 2018
2 parents cadf765 + 9bc8356 commit c9327d42aacc622c9dac493c8126f01f15ffc44d
View
@@ -0,0 +1,3 @@
/hack/
/vendor/
/www/
View
@@ -121,6 +121,15 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati
})
})
})
r.Route("/saml", func(r *router) {
r.Route("/acs", func(r *router) {
r.Use(api.loadSAMLState)
r.Post("/", api.ExternalProviderCallback)
})
r.Get("/metadata", api.SAMLMetadata)
})
})
if globalConfig.MultiInstanceMode {
View
@@ -86,43 +86,26 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
ctx := r.Context()
config := a.getConfig(ctx)
instanceID := getInstanceID(ctx)
rq := r.URL.Query()
extError := rq.Get("error")
if extError != "" {
return oauthError(extError, rq.Get("error_description"))
}
oauthCode := rq.Get("code")
if oauthCode == "" {
return badRequestError("Authorization code missing")
}
providerType := getExternalProviderType(ctx)
provider, err := a.Provider(ctx, providerType)
if err != nil {
return badRequestError("Unsupported provider: %+v", err).WithInternalError(err)
}
log := getLogEntry(r)
log.WithFields(logrus.Fields{
"provider": providerType,
"code": oauthCode,
}).Debug("Exchanging oauth code")
tok, err := provider.GetOAuthToken(oauthCode)
if err != nil {
return internalServerError("Unable to exchange external code: %s", oauthCode).WithInternalError(err)
}
userData, err := provider.GetUserData(ctx, tok)
if err != nil {
return internalServerError("Error getting user email from external provider").WithInternalError(err)
var userData *provider.UserProvidedData
if providerType == "saml" {
samlUserData, err := a.samlCallback(r, ctx)
if err != nil {
return err
}
userData = samlUserData
} else {
oAuthUserData, err := a.oAuthCallback(r, ctx, providerType)
if err != nil {
return err
}
userData = oAuthUserData
}
var user *models.User
var token *AccessTokenResponse
err = a.db.Transaction(func(tx *storage.Connection) error {
err := a.db.Transaction(func(tx *storage.Connection) error {
var terr error
inviteToken := getInviteToken(ctx)
if inviteToken != "" {
@@ -257,15 +240,7 @@ func (a *API) processInvite(ctx context.Context, tx *storage.Connection, userDat
return user, nil
}
// loadOAuthState parses the `state` query parameter as a JWS payload,
// extracting the provider requested
func (a *API) loadOAuthState(w http.ResponseWriter, r *http.Request) (context.Context, error) {
ctx := r.Context()
state := r.URL.Query().Get("state")
if state == "" {
return nil, badRequestError("OAuth state parameter missing")
}
func (a *API) loadExternalState(ctx context.Context, state string) (context.Context, error) {
claims := ExternalProviderClaims{}
p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}}
_, err := p.ParseWithClaims(state, &claims, func(token *jwt.Token) (interface{}, error) {
@@ -301,6 +276,8 @@ func (a *API) Provider(ctx context.Context, name string) (provider.Provider, err
return provider.NewGoogleProvider(config.External.Google)
case "facebook":
return provider.NewFacebookProvider(config.External.Facebook)
case "saml":
return provider.NewSamlProvider(config.External.Saml, a.db, getInstanceID(ctx))
default:
return nil, fmt.Errorf("Provider %s could not be found", name)
}
View
@@ -0,0 +1,72 @@
package api
import (
"context"
"net/http"
"github.com/netlify/gotrue/api/provider"
"github.com/sirupsen/logrus"
)
// loadOAuthState parses the `state` query parameter as a JWS payload,
// extracting the provider requested
func (a *API) loadOAuthState(w http.ResponseWriter, r *http.Request) (context.Context, error) {
state := r.URL.Query().Get("state")
if state == "" {
return nil, badRequestError("OAuth state parameter missing")
}
ctx := r.Context()
return a.loadExternalState(ctx, state)
}
func (a *API) oAuthCallback(r *http.Request, ctx context.Context, providerType string) (*provider.UserProvidedData, error) {
rq := r.URL.Query()
extError := rq.Get("error")
if extError != "" {
return nil, oauthError(extError, rq.Get("error_description"))
}
oauthCode := rq.Get("code")
if oauthCode == "" {
return nil, badRequestError("Authorization code missing")
}
oAuthProvider, err := a.OAuthProvider(ctx, providerType)
if err != nil {
return nil, badRequestError("Unsupported provider: %+v", err).WithInternalError(err)
}
log := getLogEntry(r)
log.WithFields(logrus.Fields{
"provider": providerType,
"code": oauthCode,
}).Debug("Exchanging oauth code")
tok, err := oAuthProvider.GetOAuthToken(oauthCode)
if err != nil {
return nil, internalServerError("Unable to exchange external code: %s", oauthCode).WithInternalError(err)
}
userData, err := oAuthProvider.GetUserData(ctx, tok)
if err != nil {
return nil, internalServerError("Error getting user email from external provider").WithInternalError(err)
}
return userData, nil
}
func (a *API) OAuthProvider(ctx context.Context, name string) (provider.OAuthProvider, error) {
providerCandidate, err := a.Provider(ctx, name)
if err != nil {
return nil, err
}
switch p := providerCandidate.(type) {
case provider.OAuthProvider:
return p, nil
default:
return nil, badRequestError("Provider can not be used for OAuth")
}
}
View
@@ -0,0 +1,70 @@
package api
import (
"context"
"net/http"
"github.com/netlify/gotrue/api/provider"
)
func (a *API) loadSAMLState(w http.ResponseWriter, r *http.Request) (context.Context, error) {
state := r.FormValue("RelayState")
if state == "" {
return nil, badRequestError("SAML RelayState is missing")
}
ctx := r.Context()
return a.loadExternalState(ctx, state)
}
func (a *API) samlCallback(r *http.Request, ctx context.Context) (*provider.UserProvidedData, error) {
config := a.getConfig(ctx)
samlProvider, err := provider.NewSamlProvider(config.External.Saml, a.db, getInstanceID(ctx))
if err != nil {
return nil, badRequestError("Could not initialize SAML provider: %+v", err).WithInternalError(err)
}
samlResponse := r.FormValue("SAMLResponse")
if samlResponse == "" {
return nil, badRequestError("SAML Response is missing")
}
assertionInfo, err := samlProvider.ServiceProvider.RetrieveAssertionInfo(samlResponse)
if err != nil {
return nil, internalServerError("Parsing SAML assertion failed: %+v", err).WithInternalError(err)
}
if assertionInfo.WarningInfo.InvalidTime {
return nil, forbiddenError("SAML response has invalid time")
}
if assertionInfo.WarningInfo.NotInAudience {
return nil, forbiddenError("SAML response is not in audience")
}
if assertionInfo == nil {
return nil, internalServerError("SAML Assertion is missing")
}
userData := &provider.UserProvidedData{
Email: assertionInfo.NameID,
Verified: true,
}
return userData, nil
}
func (a *API) SAMLMetadata(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
config := getConfig(ctx)
samlProvider, err := provider.NewSamlProvider(config.External.Saml, a.db, getInstanceID(ctx))
if err != nil {
return internalServerError("Could not create SAML Provider: %+v", err).WithInternalError(err)
}
metadata, err := samlProvider.SPMetadata()
w.Header().Set("Content-Type", "application/xml")
w.Write(metadata)
return nil
}
Oops, something went wrong.

0 comments on commit c9327d4

Please sign in to comment.