From e89a0495a656949481555f9589f8fbd05736fb5f Mon Sep 17 00:00:00 2001 From: Haytham Abuelfutuh Date: Tue, 10 Jan 2023 18:23:28 -0800 Subject: [PATCH] Store additional claims in the QueryUserInfoFromAccessToken path Signed-off-by: Haytham Abuelfutuh --- auth/handlers.go | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/auth/handlers.go b/auth/handlers.go index d3e451295..d3199d621 100644 --- a/auth/handlers.go +++ b/auth/handlers.go @@ -8,6 +8,9 @@ import ( "strings" "time" + _struct "github.com/golang/protobuf/ptypes/struct" + "google.golang.org/protobuf/encoding/protojson" + "github.com/flyteorg/flyteadmin/auth/interfaces" "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" @@ -410,16 +413,41 @@ func QueryUserInfoUsingAccessToken(ctx context.Context, originalRequest *http.Re userInfo, err := authCtx.OidcProvider().UserInfo(ctx, tokenSource) if err != nil { logger.Errorf(ctx, "Error getting user info from IDP %s", err) - return &service.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP") + return &service.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP. Error: %w", err) } resp := &service.UserInfoResponse{} err = userInfo.Claims(&resp) if err != nil { logger.Errorf(ctx, "Error getting user info from IDP %s", err) - return &service.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP") + return &service.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP. Error: %w", err) + } + + allClaims := make(map[string]any, 10) + err = userInfo.Claims(&allClaims) + if err != nil { + logger.Errorf(ctx, "Error unmarshalling raw claims %s", err) + return &service.UserInfoResponse{}, fmt.Errorf("error unmarshalling raw claims. Error: %w", err) } + alreadyRead := []string{"subject", "name", "preferred_username", "given_name", "family_name", "email", "picture"} + for _, existing := range alreadyRead { + delete(allClaims, existing) + } + + var response _struct.Struct + b, err := json.Marshal(allClaims) + if err != nil { + return &service.UserInfoResponse{}, fmt.Errorf("failed to marshal additional claims to json. Error: %w", err) + } + + err = protojson.Unmarshal(b, &response) + if err != nil { + return nil, fmt.Errorf("failed to unamarshal additional claims to proto.struct. Error: %w", err) + } + + resp.AdditionalClaims = &response + return resp, err }