diff --git a/pkg/api/login_oauth.go b/pkg/api/login_oauth.go index a4c4a064226a..a3599bc7a24e 100644 --- a/pkg/api/login_oauth.go +++ b/pkg/api/login_oauth.go @@ -165,6 +165,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *m.ReqContext) { extUser := &m.ExternalUserInfo{ AuthModule: "oauth_" + name, + OAuthToken: token, AuthId: userInfo.Id, Name: userInfo.Name, Login: userInfo.Login, diff --git a/pkg/api/pluginproxy/ds_proxy.go b/pkg/api/pluginproxy/ds_proxy.go index 3aec988f9e35..5f314bc2421f 100644 --- a/pkg/api/pluginproxy/ds_proxy.go +++ b/pkg/api/pluginproxy/ds_proxy.go @@ -14,8 +14,11 @@ import ( "time" "github.com/opentracing/opentracing-go" + "golang.org/x/oauth2" + "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/log" + "github.com/grafana/grafana/pkg/login/social" m "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/setting" @@ -221,6 +224,10 @@ func (proxy *DataSourceProxy) getDirector() func(req *http.Request) { if proxy.route != nil { ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds) } + + if proxy.ds.JsonData != nil && proxy.ds.JsonData.Get("oauthPassThru").MustBool() { + addOAuthPassThruAuth(proxy.ctx, req) + } } } @@ -311,3 +318,46 @@ func checkWhiteList(c *m.ReqContext, host string) bool { return true } + +func addOAuthPassThruAuth(c *m.ReqContext, req *http.Request) { + authInfoQuery := &m.GetAuthInfoQuery{UserId: c.UserId} + if err := bus.Dispatch(authInfoQuery); err != nil { + logger.Error("Error feching oauth information for user", "error", err) + return + } + + provider := authInfoQuery.Result.AuthModule + connect, ok := social.SocialMap[strings.TrimPrefix(provider, "oauth_")] // The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does + if !ok { + logger.Error("Failed to find oauth provider with given name", "provider", provider) + return + } + + // TokenSource handles refreshing the token if it has expired + token, err := connect.TokenSource(c.Req.Context(), &oauth2.Token{ + AccessToken: authInfoQuery.Result.OAuthAccessToken, + Expiry: authInfoQuery.Result.OAuthExpiry, + RefreshToken: authInfoQuery.Result.OAuthRefreshToken, + TokenType: authInfoQuery.Result.OAuthTokenType, + }).Token() + if err != nil { + logger.Error("Failed to retrieve access token from oauth provider", "provider", authInfoQuery.Result.AuthModule) + return + } + + // If the tokens are not the same, update the entry in the DB + if token.AccessToken != authInfoQuery.Result.OAuthAccessToken { + updateAuthCommand := &m.UpdateAuthInfoCommand{ + UserId: authInfoQuery.Result.Id, + AuthModule: authInfoQuery.Result.AuthModule, + AuthId: authInfoQuery.Result.AuthId, + OAuthToken: token, + } + if err := bus.Dispatch(updateAuthCommand); err != nil { + logger.Error("Failed to update access token during token refresh", "error", err) + return + } + } + req.Header.Del("Authorization") + req.Header.Add("Authorization", fmt.Sprintf("%s %s", token.Type(), token.AccessToken)) +} diff --git a/pkg/api/pluginproxy/ds_proxy_test.go b/pkg/api/pluginproxy/ds_proxy_test.go index bfad7d5670d6..368acb3a6429 100644 --- a/pkg/api/pluginproxy/ds_proxy_test.go +++ b/pkg/api/pluginproxy/ds_proxy_test.go @@ -9,10 +9,13 @@ import ( "testing" "time" + "golang.org/x/oauth2" macaron "gopkg.in/macaron.v1" + "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/log" + "github.com/grafana/grafana/pkg/login/social" m "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/setting" @@ -389,6 +392,54 @@ func TestDSRouteRule(t *testing.T) { }) }) + Convey("When proxying a datasource that has oauth token pass-thru enabled", func() { + social.SocialMap["generic_oauth"] = &social.SocialGenericOAuth{ + SocialBase: &social.SocialBase{ + Config: &oauth2.Config{}, + }, + } + + bus.AddHandler("test", func(query *m.GetAuthInfoQuery) error { + query.Result = &m.UserAuth{ + Id: 1, + UserId: 1, + AuthModule: "generic_oauth", + OAuthAccessToken: "testtoken", + OAuthRefreshToken: "testrefreshtoken", + OAuthTokenType: "Bearer", + OAuthExpiry: time.Now().AddDate(0, 0, 1), + } + return nil + }) + + plugin := &plugins.DataSourcePlugin{} + ds := &m.DataSource{ + Type: "custom-datasource", + Url: "http://host/root/", + JsonData: simplejson.NewFromAny(map[string]interface{}{ + "oauthPassThru": true, + }), + } + + req, _ := http.NewRequest("GET", "http://localhost/asd", nil) + ctx := &m.ReqContext{ + SignedInUser: &m.SignedInUser{UserId: 1}, + Context: &macaron.Context{ + Req: macaron.Request{Request: req}, + }, + } + proxy := NewDataSourceProxy(ds, plugin, ctx, "/path/to/folder/", &setting.Cfg{}) + req, err := http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil) + + So(err, ShouldBeNil) + + proxy.getDirector()(req) + + Convey("Should have access token in header", func() { + So(req.Header.Get("Authorization"), ShouldEqual, fmt.Sprintf("%s %s", "Bearer", "testtoken")) + }) + }) + Convey("When SendUserHeader config is enabled", func() { req := getDatasourceProxiedRequest( &m.ReqContext{ diff --git a/pkg/login/ext_user.go b/pkg/login/ext_user.go index f217f9fe33c2..e698110c9c98 100644 --- a/pkg/login/ext_user.go +++ b/pkg/login/ext_user.go @@ -63,11 +63,12 @@ func (ls *LoginService) UpsertUser(cmd *m.UpsertUserCommand) error { return err } - if extUser.AuthModule != "" && extUser.AuthId != "" { + if extUser.AuthModule != "" { cmd2 := &m.SetAuthInfoCommand{ UserId: cmd.Result.Id, AuthModule: extUser.AuthModule, AuthId: extUser.AuthId, + OAuthToken: extUser.OAuthToken, } if err := ls.Bus.Dispatch(cmd2); err != nil { return err @@ -81,6 +82,14 @@ func (ls *LoginService) UpsertUser(cmd *m.UpsertUserCommand) error { if err != nil { return err } + + // Always persist the latest token at log-in + if extUser.AuthModule != "" && extUser.OAuthToken != nil { + err = updateUserAuth(cmd.Result, extUser) + if err != nil { + return err + } + } } err = syncOrgRoles(cmd.Result, extUser) @@ -155,6 +164,18 @@ func updateUser(user *m.User, extUser *m.ExternalUserInfo) error { return bus.Dispatch(updateCmd) } +func updateUserAuth(user *m.User, extUser *m.ExternalUserInfo) error { + updateCmd := &m.UpdateAuthInfoCommand{ + AuthModule: extUser.AuthModule, + AuthId: extUser.AuthId, + UserId: user.Id, + OAuthToken: extUser.OAuthToken, + } + + logger.Debug("Updating user_auth info", "user_id", user.Id) + return bus.Dispatch(updateCmd) +} + func syncOrgRoles(user *m.User, extUser *m.ExternalUserInfo) error { // don't sync org roles if none are specified if len(extUser.OrgRoles) == 0 { diff --git a/pkg/login/social/social.go b/pkg/login/social/social.go index 60099a028d6d..3ec0e2c96646 100644 --- a/pkg/login/social/social.go +++ b/pkg/login/social/social.go @@ -31,6 +31,7 @@ type SocialConnector interface { AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string Exchange(ctx context.Context, code string) (*oauth2.Token, error) Client(ctx context.Context, t *oauth2.Token) *http.Client + TokenSource(ctx context.Context, t *oauth2.Token) oauth2.TokenSource } type SocialBase struct { diff --git a/pkg/models/user_auth.go b/pkg/models/user_auth.go index 28189005737c..11018c7fd716 100644 --- a/pkg/models/user_auth.go +++ b/pkg/models/user_auth.go @@ -2,17 +2,24 @@ package models import ( "time" + + "golang.org/x/oauth2" ) type UserAuth struct { - Id int64 - UserId int64 - AuthModule string - AuthId string - Created time.Time + Id int64 + UserId int64 + AuthModule string + AuthId string + Created time.Time + OAuthAccessToken string + OAuthRefreshToken string + OAuthTokenType string + OAuthExpiry time.Time } type ExternalUserInfo struct { + OAuthToken *oauth2.Token AuthModule string AuthId string UserId int64 @@ -39,6 +46,14 @@ type SetAuthInfoCommand struct { AuthModule string AuthId string UserId int64 + OAuthToken *oauth2.Token +} + +type UpdateAuthInfoCommand struct { + AuthModule string + AuthId string + UserId int64 + OAuthToken *oauth2.Token } type DeleteAuthInfoCommand struct { @@ -67,6 +82,7 @@ type GetUserByAuthInfoQuery struct { } type GetAuthInfoQuery struct { + UserId int64 AuthModule string AuthId string diff --git a/pkg/services/sqlstore/migrations/user_auth_mig.go b/pkg/services/sqlstore/migrations/user_auth_mig.go index 2771035b47f4..22be7790fa8d 100644 --- a/pkg/services/sqlstore/migrations/user_auth_mig.go +++ b/pkg/services/sqlstore/migrations/user_auth_mig.go @@ -25,4 +25,21 @@ func addUserAuthMigrations(mg *Migrator) { mg.AddMigration("alter user_auth.auth_id to length 190", NewRawSqlMigration(""). Postgres("ALTER TABLE user_auth ALTER COLUMN auth_id TYPE VARCHAR(190);"). Mysql("ALTER TABLE user_auth MODIFY auth_id VARCHAR(190);")) + + mg.AddMigration("Add OAuth access token to user_auth", NewAddColumnMigration(userAuthV1, &Column{ + Name: "o_auth_access_token", Type: DB_Text, Nullable: true, + })) + mg.AddMigration("Add OAuth refresh token to user_auth", NewAddColumnMigration(userAuthV1, &Column{ + Name: "o_auth_refresh_token", Type: DB_Text, Nullable: true, + })) + mg.AddMigration("Add OAuth token type to user_auth", NewAddColumnMigration(userAuthV1, &Column{ + Name: "o_auth_token_type", Type: DB_Text, Nullable: true, + })) + mg.AddMigration("Add OAuth expiry to user_auth", NewAddColumnMigration(userAuthV1, &Column{ + Name: "o_auth_expiry", Type: DB_DateTime, Nullable: true, + })) + + mg.AddMigration("Add index to user_id column in user_auth", NewAddIndexMigration(userAuthV1, &Index{ + Cols: []string{"user_id"}, + })) } diff --git a/pkg/services/sqlstore/user_auth.go b/pkg/services/sqlstore/user_auth.go index aec828451a46..fd8ec3d057f6 100644 --- a/pkg/services/sqlstore/user_auth.go +++ b/pkg/services/sqlstore/user_auth.go @@ -1,16 +1,22 @@ package sqlstore import ( + "encoding/base64" "time" "github.com/grafana/grafana/pkg/bus" m "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/util" ) +var getTime = time.Now + func init() { bus.AddHandler("sql", GetUserByAuthInfo) bus.AddHandler("sql", GetAuthInfo) bus.AddHandler("sql", SetAuthInfo) + bus.AddHandler("sql", UpdateAuthInfo) bus.AddHandler("sql", DeleteAuthInfo) } @@ -94,7 +100,7 @@ func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error { } // create authInfo record to link accounts - if authQuery.Result == nil && query.AuthModule != "" && query.AuthId != "" { + if authQuery.Result == nil && query.AuthModule != "" { cmd2 := &m.SetAuthInfoCommand{ UserId: user.Id, AuthModule: query.AuthModule, @@ -111,10 +117,11 @@ func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error { func GetAuthInfo(query *m.GetAuthInfoQuery) error { userAuth := &m.UserAuth{ + UserId: query.UserId, AuthModule: query.AuthModule, AuthId: query.AuthId, } - has, err := x.Get(userAuth) + has, err := x.Desc("created").Get(userAuth) if err != nil { return err } @@ -122,6 +129,22 @@ func GetAuthInfo(query *m.GetAuthInfoQuery) error { return m.ErrUserNotFound } + secretAccessToken, err := decodeAndDecrypt(userAuth.OAuthAccessToken) + if err != nil { + return err + } + secretRefreshToken, err := decodeAndDecrypt(userAuth.OAuthRefreshToken) + if err != nil { + return err + } + secretTokenType, err := decodeAndDecrypt(userAuth.OAuthTokenType) + if err != nil { + return err + } + userAuth.OAuthAccessToken = secretAccessToken + userAuth.OAuthRefreshToken = secretRefreshToken + userAuth.OAuthTokenType = secretTokenType + query.Result = userAuth return nil } @@ -132,7 +155,27 @@ func SetAuthInfo(cmd *m.SetAuthInfoCommand) error { UserId: cmd.UserId, AuthModule: cmd.AuthModule, AuthId: cmd.AuthId, - Created: time.Now(), + Created: getTime(), + } + + if cmd.OAuthToken != nil { + secretAccessToken, err := encryptAndEncode(cmd.OAuthToken.AccessToken) + if err != nil { + return err + } + secretRefreshToken, err := encryptAndEncode(cmd.OAuthToken.RefreshToken) + if err != nil { + return err + } + secretTokenType, err := encryptAndEncode(cmd.OAuthToken.TokenType) + if err != nil { + return err + } + + authUser.OAuthAccessToken = secretAccessToken + authUser.OAuthRefreshToken = secretRefreshToken + authUser.OAuthTokenType = secretTokenType + authUser.OAuthExpiry = cmd.OAuthToken.Expiry } _, err := sess.Insert(authUser) @@ -140,9 +183,76 @@ func SetAuthInfo(cmd *m.SetAuthInfoCommand) error { }) } +func UpdateAuthInfo(cmd *m.UpdateAuthInfoCommand) error { + return inTransaction(func(sess *DBSession) error { + authUser := &m.UserAuth{ + UserId: cmd.UserId, + AuthModule: cmd.AuthModule, + AuthId: cmd.AuthId, + Created: getTime(), + } + + if cmd.OAuthToken != nil { + secretAccessToken, err := encryptAndEncode(cmd.OAuthToken.AccessToken) + if err != nil { + return err + } + secretRefreshToken, err := encryptAndEncode(cmd.OAuthToken.RefreshToken) + if err != nil { + return err + } + secretTokenType, err := encryptAndEncode(cmd.OAuthToken.TokenType) + if err != nil { + return err + } + + authUser.OAuthAccessToken = secretAccessToken + authUser.OAuthRefreshToken = secretRefreshToken + authUser.OAuthTokenType = secretTokenType + authUser.OAuthExpiry = cmd.OAuthToken.Expiry + } + + cond := &m.UserAuth{ + UserId: cmd.UserId, + AuthModule: cmd.AuthModule, + } + + _, err := sess.Update(authUser, cond) + return err + }) +} + func DeleteAuthInfo(cmd *m.DeleteAuthInfoCommand) error { return inTransaction(func(sess *DBSession) error { _, err := sess.Delete(cmd.UserAuth) return err }) } + +// decodeAndDecrypt will decode the string with the standard bas64 decoder +// and then decrypt it with grafana's secretKey +func decodeAndDecrypt(s string) (string, error) { + // Bail out if empty string since it'll cause a segfault in util.Decrypt + if s == "" { + return "", nil + } + decoded, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return "", err + } + decrypted, err := util.Decrypt(decoded, setting.SecretKey) + if err != nil { + return "", err + } + return string(decrypted), nil +} + +// encryptAndEncode will encrypt a string with grafana's secretKey, and +// then encode it with the standard bas64 encoder +func encryptAndEncode(s string) (string, error) { + encrypted, err := util.Encrypt([]byte(s), setting.SecretKey) + if err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(encrypted), nil +} diff --git a/pkg/services/sqlstore/user_auth_test.go b/pkg/services/sqlstore/user_auth_test.go index a0dd714fe6f7..8a8213b8f878 100644 --- a/pkg/services/sqlstore/user_auth_test.go +++ b/pkg/services/sqlstore/user_auth_test.go @@ -4,8 +4,10 @@ import ( "context" "fmt" "testing" + "time" . "github.com/smartystreets/goconvey/convey" + "golang.org/x/oauth2" m "github.com/grafana/grafana/pkg/models" ) @@ -126,5 +128,97 @@ func TestUserAuth(t *testing.T) { So(err, ShouldEqual, m.ErrUserNotFound) So(query.Result, ShouldBeNil) }) + + Convey("Can set & retrieve oauth token information", func() { + token := &oauth2.Token{ + AccessToken: "testaccess", + RefreshToken: "testrefresh", + Expiry: time.Now(), + TokenType: "Bearer", + } + + // Find a user to set tokens on + login := "loginuser0" + + // Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table + query := &m.GetUserByAuthInfoQuery{Login: login, AuthModule: "test", AuthId: "test"} + err = GetUserByAuthInfo(query) + + So(err, ShouldBeNil) + So(query.Result.Login, ShouldEqual, login) + + cmd := &m.UpdateAuthInfoCommand{ + UserId: query.Result.Id, + AuthId: query.AuthId, + AuthModule: query.AuthModule, + OAuthToken: token, + } + err = UpdateAuthInfo(cmd) + + So(err, ShouldBeNil) + + getAuthQuery := &m.GetAuthInfoQuery{ + UserId: query.Result.Id, + } + + err = GetAuthInfo(getAuthQuery) + + So(err, ShouldBeNil) + So(getAuthQuery.Result.OAuthAccessToken, ShouldEqual, token.AccessToken) + So(getAuthQuery.Result.OAuthRefreshToken, ShouldEqual, token.RefreshToken) + So(getAuthQuery.Result.OAuthTokenType, ShouldEqual, token.TokenType) + + }) + + Convey("Always return the most recently used auth_module", func() { + // Find a user to set tokens on + login := "loginuser0" + + // Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table + // Make the first log-in during the past + getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) } + query := &m.GetUserByAuthInfoQuery{Login: login, AuthModule: "test1", AuthId: "test1"} + err = GetUserByAuthInfo(query) + getTime = time.Now + + So(err, ShouldBeNil) + So(query.Result.Login, ShouldEqual, login) + + // Add a second auth module for this user + // Have this module's last log-in be more recent + getTime = func() time.Time { return time.Now().AddDate(0, 0, -1) } + query = &m.GetUserByAuthInfoQuery{Login: login, AuthModule: "test2", AuthId: "test2"} + err = GetUserByAuthInfo(query) + getTime = time.Now + + So(err, ShouldBeNil) + So(query.Result.Login, ShouldEqual, login) + + // Get the latest entry by not supply an authmodule or authid + getAuthQuery := &m.GetAuthInfoQuery{ + UserId: query.Result.Id, + } + + err = GetAuthInfo(getAuthQuery) + + So(err, ShouldBeNil) + So(getAuthQuery.Result.AuthModule, ShouldEqual, "test2") + + // "log in" again with the first auth module + updateAuthCmd := &m.UpdateAuthInfoCommand{UserId: query.Result.Id, AuthModule: "test1", AuthId: "test1"} + err = UpdateAuthInfo(updateAuthCmd) + + So(err, ShouldBeNil) + + // Get the latest entry by not supply an authmodule or authid + getAuthQuery = &m.GetAuthInfoQuery{ + UserId: query.Result.Id, + } + + err = GetAuthInfo(getAuthQuery) + + So(err, ShouldBeNil) + So(getAuthQuery.Result.AuthModule, ShouldEqual, "test1") + }) }) } diff --git a/public/app/features/datasources/partials/http_settings.html b/public/app/features/datasources/partials/http_settings.html index b4cf1084843c..755284bf7a8e 100644 --- a/public/app/features/datasources/partials/http_settings.html +++ b/public/app/features/datasources/partials/http_settings.html @@ -71,22 +71,25 @@

HTTP

Auth

- +
-
-
+
+ +
@@ -102,4 +105,4 @@
Basic Auth Details
- \ No newline at end of file +