diff --git a/docs/config.md b/docs/config.md index 5bcbba5..10e3492 100644 --- a/docs/config.md +++ b/docs/config.md @@ -46,8 +46,8 @@ Database Playground 使用 PostgreSQL 作為資料庫。 - `GAUTH_CLIENT_ID`:Google OAuth 的 Client ID - `GAUTH_CLIENT_SECRET`:Google OAuth 的 Client Secret -- `GAUTH_REDIRECT_URI`:在完成 Google OAuth 流程後,要重新導向到的 URL,通常是指向前端。 - - 舉例:你的前端會在進入起始連結前,記錄目前頁面的位址,然後在 `/auth/completed` endpoint 重新導向回使用者上次瀏覽的連結。這時,你可以將重新導向連結寫為 `https://app.yourdomain.tld/auth/completed`。如果你沒有這樣的 endpoint,寫上前端的首頁也是可以的。注意在起始連結帶入的 `state` 會被帶入這個 URI 中。 +- `GAUTH_REDIRECT_URIS`:在完成 Google OAuth 流程後,允許重新導向到的 URIs。 + - 舉例:`https://admin.dbplay.app` Google OAuth 的登入起始連結為 `https://backend.yourdomain.tld/api/auth/google/login`,可選擇性帶入 `state` 參數。 Google OAuth 的回呼連結為 `https://backend.yourdomain.tld/api/auth/google/callback`。 diff --git a/httpapi/auth/README.md b/httpapi/auth/README.md index e6b82ab..91157ab 100644 --- a/httpapi/auth/README.md +++ b/httpapi/auth/README.md @@ -27,6 +27,6 @@ Auth 端點提供適合供網頁應用程式使用的認證 API。 ## Google 登入 -如果您要觸發 Google 登入的流程,請前往 `GET /api/auth/google/login`。 +如果您要觸發 Google 登入的流程,請前往 `GET /api/auth/google/login`。可以帶入 `redirect_uri` 參數來在登入完成後轉導到指定畫面。 這個頁面會重新導向到 Google 的登入頁面,登入後會回到 `POST /api/auth/google/callback` 並進行帳號登入和註冊手續。 diff --git a/httpapi/auth/gauth.go b/httpapi/auth/gauth.go index 5e3579d..88c7317 100644 --- a/httpapi/auth/gauth.go +++ b/httpapi/auth/gauth.go @@ -1,6 +1,8 @@ package authservice import ( + "errors" + "fmt" "net/http" "net/url" @@ -16,7 +18,10 @@ import ( "google.golang.org/api/option" ) -const verifierCookieName = "Gauth-Verifier" +const ( + verifierCookieName = "Gauth-Verifier" + redirectCookieName = "Gauth-Redirect" +) // BuildOAuthConfig builds an oauth2.Config from a gauthConfig. func BuildOAuthConfig(gauthConfig config.GAuthConfig) *oauth2.Config { @@ -32,13 +37,13 @@ func BuildOAuthConfig(gauthConfig config.GAuthConfig) *oauth2.Config { } type GauthHandler struct { - oauthConfig *oauth2.Config - useraccount *useraccount.Context - redirectURL string + oauthConfig *oauth2.Config + useraccount *useraccount.Context + redirectURIs []string } -func NewGauthHandler(oauthConfig *oauth2.Config, useraccount *useraccount.Context, redirectURL string) *GauthHandler { - return &GauthHandler{oauthConfig: oauthConfig, useraccount: useraccount, redirectURL: redirectURL} +func NewGauthHandler(oauthConfig *oauth2.Config, useraccount *useraccount.Context, redirectURIs []string) *GauthHandler { + return &GauthHandler{oauthConfig: oauthConfig, useraccount: useraccount, redirectURIs: redirectURIs} } func (h *GauthHandler) Login(c *gin.Context) { @@ -55,6 +60,11 @@ func (h *GauthHandler) Login(c *gin.Context) { return } + redirectURI := c.Query("redirect_uri") + if redirectURI == "" { + redirectURI = h.oauthConfig.RedirectURL + } + callbackURL, err := url.Parse(h.oauthConfig.RedirectURL) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ @@ -74,6 +84,16 @@ func (h *GauthHandler) Login(c *gin.Context) { /* httpOnly */ true, ) + c.SetCookie( + /* name */ redirectCookieName, + /* value */ redirectURI, + /* maxAge */ 5*60, // 5 min + /* path */ "/", + /* domain */ "", + /* secure */ true, + /* httpOnly */ true, + ) + redirectURL := h.oauthConfig.AuthCodeURL( "", oauth2.AccessTypeOnline, @@ -160,5 +180,55 @@ func (h *GauthHandler) Callback(c *gin.Context) { /* httpOnly */ true, ) - c.Redirect(http.StatusTemporaryRedirect, h.redirectURL) + // redirect to the original redirect URL + redirectURL, err := c.Cookie(redirectCookieName) + if err != nil { + if errors.Is(err, http.ErrNoCookie) { + c.JSON(http.StatusOK, gin.H{ + "success": true, + }) + return + } + + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "failed to get redirect URL", + "detail": err.Error(), + }) + return + } + + // check if the redirect URL is in the allowed redirect URIs + userRedirectURL, err := url.Parse(redirectURL) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "failed to parse redirect URL", + "detail": err.Error(), + }) + return + } + + for _, allowedRedirectURI := range h.redirectURIs { + parsedAllowedRedirectURI, err := url.Parse(allowedRedirectURI) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "failed to parse allowed redirect URI", + "detail": err.Error(), + }) + return + } + + matched := userRedirectURL.Scheme == parsedAllowedRedirectURI.Scheme && + userRedirectURL.Host == parsedAllowedRedirectURI.Host && + userRedirectURL.Path == parsedAllowedRedirectURI.Path + + if matched { + c.Redirect(http.StatusTemporaryRedirect, parsedAllowedRedirectURI.String()) + return + } + } + + c.JSON(http.StatusBadRequest, gin.H{ + "error": "redirect URL is not allowed", + "detail": fmt.Sprintf("redirect URL is not allowed: %s", redirectURL), + }) } diff --git a/httpapi/auth/root.go b/httpapi/auth/root.go index 7aabf10..0b4a6e0 100644 --- a/httpapi/auth/root.go +++ b/httpapi/auth/root.go @@ -34,7 +34,7 @@ func (s *AuthService) Register(router gin.IRouter) { useraccount := useraccount.NewContext(s.entClient, s.storage) - gauthHandler := NewGauthHandler(oauthConfig, useraccount, s.config.GAuth.RedirectURL) + gauthHandler := NewGauthHandler(oauthConfig, useraccount, s.config.GAuth.RedirectURIs) gauth.GET("/login", gauthHandler.Login) gauth.GET("/callback", gauthHandler.Callback) diff --git a/internal/config/models.go b/internal/config/models.go index f69c9c5..2c36602 100644 --- a/internal/config/models.go +++ b/internal/config/models.go @@ -71,9 +71,9 @@ func (c RedisConfig) Validate() error { } type GAuthConfig struct { - ClientID string `env:"CLIENT_ID"` - ClientSecret string `env:"CLIENT_SECRET"` - RedirectURL string `env:"REDIRECT_URL"` + ClientID string `env:"CLIENT_ID"` + ClientSecret string `env:"CLIENT_SECRET"` + RedirectURIs []string `env:"REDIRECT_URIS"` } func (c GAuthConfig) Validate() error { @@ -83,8 +83,8 @@ func (c GAuthConfig) Validate() error { if c.ClientSecret == "" { return errors.New("GAUTH_CLIENT_SECRET is required") } - if c.RedirectURL == "" { - return errors.New("GAUTH_REDIRECT_URL is required") + if len(c.RedirectURIs) == 0 { + return errors.New("GAUTH_REDIRECT_URIS is required") } return nil