Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 29 additions & 27 deletions oauth2cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,7 @@ type Config struct {
// OAuth2 config.
// RedirectURL will be automatically set to the local server.
OAuth2Config oauth2.Config
// Hostname of the redirect URL.
// You can set this if your provider does not accept localhost.
// Default to localhost.
RedirectURLHostname string

// Options for an authorization request.
// You can set oauth2.AccessTypeOffline and the PKCE options here.
AuthCodeOptions []oauth2.AuthCodeOption
Expand All @@ -66,6 +63,11 @@ type Config struct {
// Default to a string of random 32 bytes.
State string

// Hostname of the redirect URL.
// You can set this if your provider does not accept localhost.
// Default to localhost.
RedirectURLHostname string

// Candidates of hostname and port which the local server binds to.
// You can set port number to 0 to allocate a free port.
// If multiple addresses are given, it will try the ports in order.
Expand Down Expand Up @@ -98,37 +100,37 @@ type Config struct {
Logf func(format string, args ...interface{})
}

func (c *Config) isLocalServerHTTPS() bool {
return c.LocalServerCertFile != "" && c.LocalServerKeyFile != ""
func (cfg *Config) isLocalServerHTTPS() bool {
return cfg.LocalServerCertFile != "" && cfg.LocalServerKeyFile != ""
}

func (c *Config) validateAndSetDefaults() error {
if (c.LocalServerCertFile != "" && c.LocalServerKeyFile == "") ||
(c.LocalServerCertFile == "" && c.LocalServerKeyFile != "") {
func (cfg *Config) validateAndSetDefaults() error {
if (cfg.LocalServerCertFile != "" && cfg.LocalServerKeyFile == "") ||
(cfg.LocalServerCertFile == "" && cfg.LocalServerKeyFile != "") {
return fmt.Errorf("both LocalServerCertFile and LocalServerKeyFile must be set")
}
if c.RedirectURLHostname == "" {
c.RedirectURLHostname = "localhost"
if cfg.RedirectURLHostname == "" {
cfg.RedirectURLHostname = "localhost"
}
if c.State == "" {
s, err := oauth2params.NewState()
if cfg.State == "" {
state, err := oauth2params.NewState()
if err != nil {
return fmt.Errorf("could not generate a state parameter: %w", err)
}
c.State = s
cfg.State = state
}
if c.LocalServerMiddleware == nil {
c.LocalServerMiddleware = noopMiddleware
if cfg.LocalServerMiddleware == nil {
cfg.LocalServerMiddleware = noopMiddleware
}
if c.LocalServerSuccessHTML == "" {
c.LocalServerSuccessHTML = DefaultLocalServerSuccessHTML
if cfg.LocalServerSuccessHTML == "" {
cfg.LocalServerSuccessHTML = DefaultLocalServerSuccessHTML
}
if (c.SuccessRedirectURL != "" && c.FailureRedirectURL == "") ||
(c.SuccessRedirectURL == "" && c.FailureRedirectURL != "") {
if (cfg.SuccessRedirectURL != "" && cfg.FailureRedirectURL == "") ||
(cfg.SuccessRedirectURL == "" && cfg.FailureRedirectURL != "") {
return fmt.Errorf("when using success and failure redirect URLs, set both URLs")
}
if c.Logf == nil {
c.Logf = func(string, ...interface{}) {}
if cfg.Logf == nil {
cfg.Logf = func(string, ...interface{}) {}
}
return nil
}
Expand All @@ -144,16 +146,16 @@ func (c *Config) validateAndSetDefaults() error {
// 4. Receive a code via an authorization response (HTTP redirect).
// 5. Exchange the code and a token.
// 6. Return the code.
func GetToken(ctx context.Context, c Config) (*oauth2.Token, error) {
if err := c.validateAndSetDefaults(); err != nil {
func GetToken(ctx context.Context, cfg Config) (*oauth2.Token, error) {
if err := cfg.validateAndSetDefaults(); err != nil {
return nil, fmt.Errorf("invalid config: %w", err)
}
code, err := receiveCodeViaLocalServer(ctx, &c)
code, err := receiveCodeViaLocalServer(ctx, &cfg)
if err != nil {
return nil, fmt.Errorf("authorization error: %w", err)
}
c.Logf("oauth2cli: exchanging the code and token")
token, err := c.OAuth2Config.Exchange(ctx, code, c.TokenRequestOptions...)
cfg.Logf("oauth2cli: exchanging the code and token")
token, err := cfg.OAuth2Config.Exchange(ctx, code, cfg.TokenRequestOptions...)
if err != nil {
return nil, fmt.Errorf("could not exchange the code and token: %w", err)
}
Expand Down
65 changes: 36 additions & 29 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,28 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"sync"
"time"

"github.com/int128/listener"
"golang.org/x/sync/errgroup"
)

func receiveCodeViaLocalServer(ctx context.Context, c *Config) (string, error) {
l, err := listener.New(c.LocalServerBindAddress)
func receiveCodeViaLocalServer(ctx context.Context, cfg *Config) (string, error) {
localServerListener, err := listener.New(cfg.LocalServerBindAddress)
if err != nil {
return "", fmt.Errorf("could not start a local server: %w", err)
}
defer l.Close()
c.OAuth2Config.RedirectURL = computeRedirectURL(l, c)
defer localServerListener.Close()

localServerPort := localServerListener.Addr().(*net.TCPAddr).Port
cfg.OAuth2Config.RedirectURL = constructRedirectURL(cfg, localServerPort)

respCh := make(chan *authorizationResponse)
server := http.Server{
Handler: c.LocalServerMiddleware(&localServerHandler{
config: c,
Handler: cfg.LocalServerMiddleware(&localServerHandler{
config: cfg,
respCh: respCh,
}),
}
Expand All @@ -33,15 +36,18 @@ func receiveCodeViaLocalServer(ctx context.Context, c *Config) (string, error) {
var eg errgroup.Group
eg.Go(func() error {
defer close(respCh)
c.Logf("oauth2cli: starting a server at %s", l.Addr())
defer c.Logf("oauth2cli: stopped the server")
if c.isLocalServerHTTPS() {
if err := server.ServeTLS(l, c.LocalServerCertFile, c.LocalServerKeyFile); err != nil && err != http.ErrServerClosed {
cfg.Logf("oauth2cli: starting a server at %s", localServerListener.Addr())
defer cfg.Logf("oauth2cli: stopped the server")
if cfg.isLocalServerHTTPS() {
if err := server.ServeTLS(localServerListener, cfg.LocalServerCertFile, cfg.LocalServerKeyFile); err != nil {
if errors.Is(err, http.ErrServerClosed) {
return nil
}
return fmt.Errorf("could not start HTTPS server: %w", err)
}
return nil
}
if err := server.Serve(l); err != nil && err != http.ErrServerClosed {
if err := server.Serve(localServerListener); err != nil && err != http.ErrServerClosed {
return fmt.Errorf("could not start HTTP server: %w", err)
}
return nil
Expand All @@ -63,22 +69,22 @@ func receiveCodeViaLocalServer(ctx context.Context, c *Config) (string, error) {
// Gracefully shutdown the server in the timeout.
// If the server has not started, Shutdown returns nil and this returns immediately.
// If Shutdown has failed, force-close the server.
c.Logf("oauth2cli: shutting down the server")
cfg.Logf("oauth2cli: shutting down the server")
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
c.Logf("oauth2cli: force-closing the server: shutdown failed: %s", err)
cfg.Logf("oauth2cli: force-closing the server: shutdown failed: %s", err)
_ = server.Close()
return nil
}
return nil
})
eg.Go(func() error {
if c.LocalServerReadyChan == nil {
if cfg.LocalServerReadyChan == nil {
return nil
}
select {
case c.LocalServerReadyChan <- c.OAuth2Config.RedirectURL:
case cfg.LocalServerReadyChan <- cfg.OAuth2Config.RedirectURL:
return nil
case <-ctx.Done():
return ctx.Err()
Expand All @@ -93,12 +99,14 @@ func receiveCodeViaLocalServer(ctx context.Context, c *Config) (string, error) {
return resp.code, resp.err
}

func computeRedirectURL(l net.Listener, c *Config) string {
hostPort := fmt.Sprintf("%s:%d", c.RedirectURLHostname, l.Addr().(*net.TCPAddr).Port)
if c.LocalServerCertFile != "" {
return "https://" + hostPort
func constructRedirectURL(cfg *Config, port int) string {
var redirect url.URL
redirect.Host = fmt.Sprintf("%s:%d", cfg.RedirectURLHostname, port)
redirect.Scheme = "http"
if cfg.isLocalServerHTTPS() {
redirect.Scheme = "https"
}
return "http://" + hostPort
return redirect.String()
}

type authorizationResponse struct {
Expand Down Expand Up @@ -133,7 +141,7 @@ func (h *localServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (h *localServerHandler) handleIndex(w http.ResponseWriter, r *http.Request) {
authCodeURL := h.config.OAuth2Config.AuthCodeURL(h.config.State, h.config.AuthCodeOptions...)
h.config.Logf("oauth2cli: sending redirect to %s", authCodeURL)
http.Redirect(w, r, authCodeURL, 302)
http.Redirect(w, r, authCodeURL, http.StatusFound)
}

func (h *localServerHandler) handleCodeResponse(w http.ResponseWriter, r *http.Request) *authorizationResponse {
Expand All @@ -147,21 +155,20 @@ func (h *localServerHandler) handleCodeResponse(w http.ResponseWriter, r *http.R

if h.config.SuccessRedirectURL != "" {
http.Redirect(w, r, h.config.SuccessRedirectURL, http.StatusFound)
} else {
w.Header().Add("Content-Type", "text/html")
if _, err := fmt.Fprint(w, h.config.LocalServerSuccessHTML); err != nil {
http.Error(w, "server error", 500)
return &authorizationResponse{err: fmt.Errorf("write error: %w", err)}
}
return &authorizationResponse{code: code}
}

w.Header().Add("Content-Type", "text/html")
if _, err := fmt.Fprint(w, h.config.LocalServerSuccessHTML); err != nil {
http.Error(w, "server error", http.StatusInternalServerError)
return &authorizationResponse{err: fmt.Errorf("write error: %w", err)}
}
return &authorizationResponse{code: code}
}

func (h *localServerHandler) handleErrorResponse(w http.ResponseWriter, r *http.Request) *authorizationResponse {
q := r.URL.Query()
errorCode, errorDescription := q.Get("error"), q.Get("error_description")

h.authorizationError(w, r)
return &authorizationResponse{err: fmt.Errorf("authorization error from server: %s %s", errorCode, errorDescription)}
}
Expand All @@ -170,6 +177,6 @@ func (h *localServerHandler) authorizationError(w http.ResponseWriter, r *http.R
if h.config.FailureRedirectURL != "" {
http.Redirect(w, r, h.config.FailureRedirectURL, http.StatusFound)
} else {
http.Error(w, "authorization error", 500)
http.Error(w, "authorization error", http.StatusInternalServerError)
}
}
Loading