From 2bcde8998b192ef783ff126b033383f842f924a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Sat, 17 Dec 2022 18:43:00 +0100 Subject: [PATCH] Introduce Go context-aware Wait functions for blocking operation (#39) --- device/device_flow.go | 42 ++++++++++------------ device/device_flow_test.go | 72 +++++++++++++++++--------------------- device/examples_test.go | 3 +- device/poller.go | 43 +++++++++++++++++++++++ oauth_device.go | 3 +- oauth_webapp.go | 5 ++- webapp/examples_test.go | 5 ++- webapp/local_server.go | 10 ++++-- webapp/webapp_flow.go | 18 ++++++++-- webapp/webapp_flow_test.go | 3 +- 10 files changed, 133 insertions(+), 71 deletions(-) create mode 100644 device/poller.go diff --git a/device/device_flow.go b/device/device_flow.go index f14eb43..82fc479 100644 --- a/device/device_flow.go +++ b/device/device_flow.go @@ -13,6 +13,7 @@ package device import ( + "context" "errors" "fmt" "net/http" @@ -103,16 +104,16 @@ const defaultGrantType = "urn:ietf:params:oauth:grant-type:device_code" // PollToken polls the server at pollURL until an access token is granted or denied. // -// Deprecated: use PollTokenWithOptions. +// Deprecated: use Wait. func PollToken(c httpClient, pollURL string, clientID string, code *CodeResponse) (*api.AccessToken, error) { - return PollTokenWithOptions(c, pollURL, PollOptions{ + return Wait(context.Background(), c, pollURL, WaitOptions{ ClientID: clientID, DeviceCode: code, }) } -// PollOptions specifies parameters to poll the server with until authentication completes. -type PollOptions struct { +// WaitOptions specifies parameters to poll the server with until authentication completes. +type WaitOptions struct { // ClientID is the app client ID value. ClientID string // ClientSecret is the app client secret value. Optional: only pass if the server requires it. @@ -122,30 +123,28 @@ type PollOptions struct { // GrantType overrides the default value specified by OAuth 2.0 Device Code. Optional. GrantType string - timeNow func() time.Time - timeSleep func(time.Duration) + newPoller pollerFactory } -// PollTokenWithOptions polls the server at uri until authorization completes. -func PollTokenWithOptions(c httpClient, uri string, opts PollOptions) (*api.AccessToken, error) { - timeNow := opts.timeNow - if timeNow == nil { - timeNow = time.Now - } - timeSleep := opts.timeSleep - if timeSleep == nil { - timeSleep = time.Sleep - } - +// Wait polls the server at uri until authorization completes. +func Wait(ctx context.Context, c httpClient, uri string, opts WaitOptions) (*api.AccessToken, error) { checkInterval := time.Duration(opts.DeviceCode.Interval) * time.Second - expiresAt := timeNow().Add(time.Duration(opts.DeviceCode.ExpiresIn) * time.Second) + expiresIn := time.Duration(opts.DeviceCode.ExpiresIn) * time.Second grantType := opts.GrantType if opts.GrantType == "" { grantType = defaultGrantType } + makePoller := opts.newPoller + if makePoller == nil { + makePoller = newPoller + } + _, poll := makePoller(ctx, checkInterval, expiresIn) + for { - timeSleep(checkInterval) + if err := poll.Wait(); err != nil { + return nil, err + } values := url.Values{ "client_id": {opts.ClientID}, @@ -158,6 +157,7 @@ func PollTokenWithOptions(c httpClient, uri string, opts PollOptions) (*api.Acce values.Add("client_secret", opts.ClientSecret) } + // TODO: pass tctx down to the HTTP layer resp, err := api.PostForm(c, uri, values) if err != nil { return nil, err @@ -170,9 +170,5 @@ func PollTokenWithOptions(c httpClient, uri string, opts PollOptions) (*api.Acce } else if !(errors.As(err, &apiError) && apiError.Code == "authorization_pending") { return nil, err } - - if timeNow().After(expiresAt) { - return nil, ErrTimeout - } } } diff --git a/device/device_flow_test.go b/device/device_flow_test.go index d8fb4f3..238e14c 100644 --- a/device/device_flow_test.go +++ b/device/device_flow_test.go @@ -2,6 +2,8 @@ package device import ( "bytes" + "context" + "errors" "io/ioutil" "net/http" "net/url" @@ -230,28 +232,16 @@ func TestRequestCode(t *testing.T) { } func TestPollToken(t *testing.T) { - var totalSlept time.Duration - mockSleep := func(d time.Duration) { - totalSlept += d - } - duration := func(d string) time.Duration { - res, _ := time.ParseDuration(d) - return res - } - clock := func(durations ...string) func() time.Time { - count := 0 - now := time.Now() - return func() time.Time { - t := now.Add(duration(durations[count])) - count++ - return t + makeFakePoller := func(maxWaits int) pollerFactory { + return func(ctx context.Context, interval, expiresIn time.Duration) (context.Context, poller) { + return ctx, &fakePoller{maxWaits: maxWaits} } } type args struct { http apiClient url string - opts PollOptions + opts WaitOptions } tests := []struct { name string @@ -279,7 +269,7 @@ func TestPollToken(t *testing.T) { }, }, url: "https://github.com/oauth", - opts: PollOptions{ + opts: WaitOptions{ ClientID: "CLIENT-ID", DeviceCode: &CodeResponse{ DeviceCode: "DEVIC", @@ -288,14 +278,12 @@ func TestPollToken(t *testing.T) { ExpiresIn: 99, Interval: 5, }, - timeSleep: mockSleep, - timeNow: clock("0", "5s", "10s"), + newPoller: makeFakePoller(2), }, }, want: &api.AccessToken{ Token: "123abc", }, - slept: duration("10s"), posts: []postArgs{ { url: "https://github.com/oauth", @@ -328,7 +316,7 @@ func TestPollToken(t *testing.T) { }, }, url: "https://github.com/oauth", - opts: PollOptions{ + opts: WaitOptions{ ClientID: "CLIENT-ID", ClientSecret: "SEKRIT", GrantType: "device_code", @@ -339,14 +327,12 @@ func TestPollToken(t *testing.T) { ExpiresIn: 99, Interval: 5, }, - timeSleep: mockSleep, - timeNow: clock("0", "5s", "10s"), + newPoller: makeFakePoller(1), }, }, want: &api.AccessToken{ Token: "123abc", }, - slept: duration("5s"), posts: []postArgs{ { url: "https://github.com/oauth", @@ -377,21 +363,19 @@ func TestPollToken(t *testing.T) { }, }, url: "https://github.com/oauth", - opts: PollOptions{ + opts: WaitOptions{ ClientID: "CLIENT-ID", DeviceCode: &CodeResponse{ DeviceCode: "DEVIC", UserCode: "123-abc", VerificationURI: "http://verify.me", - ExpiresIn: 99, + ExpiresIn: 14, Interval: 5, }, - timeSleep: mockSleep, - timeNow: clock("0", "5s", "15m"), + newPoller: makeFakePoller(2), }, }, - wantErr: "authentication timed out", - slept: duration("10s"), + wantErr: "context deadline exceeded", posts: []postArgs{ { url: "https://github.com/oauth", @@ -424,7 +408,7 @@ func TestPollToken(t *testing.T) { }, }, url: "https://github.com/oauth", - opts: PollOptions{ + opts: WaitOptions{ ClientID: "CLIENT-ID", DeviceCode: &CodeResponse{ DeviceCode: "DEVIC", @@ -433,12 +417,10 @@ func TestPollToken(t *testing.T) { ExpiresIn: 99, Interval: 5, }, - timeSleep: mockSleep, - timeNow: clock("0", "5s"), + newPoller: makeFakePoller(1), }, }, wantErr: "access_denied", - slept: duration("5s"), posts: []postArgs{ { url: "https://github.com/oauth", @@ -453,8 +435,7 @@ func TestPollToken(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - totalSlept = 0 - got, err := PollTokenWithOptions(&tt.args.http, tt.args.url, tt.args.opts) + got, err := Wait(context.Background(), &tt.args.http, tt.args.url, tt.args.opts) if (err != nil) != (tt.wantErr != "") { t.Errorf("PollToken() error = %v, wantErr %v", err, tt.wantErr) return @@ -468,9 +449,22 @@ func TestPollToken(t *testing.T) { if !reflect.DeepEqual(tt.args.http.calls, tt.posts) { t.Errorf("PostForm() = %v, want %v", tt.args.http.calls, tt.posts) } - if totalSlept != tt.slept { - t.Errorf("slept %v, wanted %v", totalSlept, tt.slept) - } }) } } + +type fakePoller struct { + maxWaits int + count int +} + +func (p *fakePoller) Wait() error { + if p.count == p.maxWaits { + return errors.New("context deadline exceeded") + } + p.count++ + return nil +} + +func (p *fakePoller) Cancel() { +} diff --git a/device/examples_test.go b/device/examples_test.go index 778a5ac..2d78910 100644 --- a/device/examples_test.go +++ b/device/examples_test.go @@ -1,6 +1,7 @@ package device import ( + "context" "fmt" "net/http" "os" @@ -22,7 +23,7 @@ func Example() { fmt.Printf("Copy code: %s\n", code.UserCode) fmt.Printf("then open: %s\n", code.VerificationURI) - accessToken, err := PollTokenWithOptions(httpClient, "https://github.com/login/oauth/access_token", PollOptions{ + accessToken, err := Wait(context.TODO(), httpClient, "https://github.com/login/oauth/access_token", WaitOptions{ ClientID: clientID, DeviceCode: code, }) diff --git a/device/poller.go b/device/poller.go new file mode 100644 index 0000000..06a2e9d --- /dev/null +++ b/device/poller.go @@ -0,0 +1,43 @@ +package device + +import ( + "context" + "time" +) + +type poller interface { + Wait() error + Cancel() +} + +type pollerFactory func(context.Context, time.Duration, time.Duration) (context.Context, poller) + +func newPoller(ctx context.Context, checkInteval, expiresIn time.Duration) (context.Context, poller) { + c, cancel := context.WithTimeout(ctx, expiresIn) + return c, &intervalPoller{ + ctx: c, + interval: checkInteval, + cancelFunc: cancel, + } +} + +type intervalPoller struct { + ctx context.Context + interval time.Duration + cancelFunc func() +} + +func (p intervalPoller) Wait() error { + t := time.NewTimer(p.interval) + select { + case <-p.ctx.Done(): + t.Stop() + return p.ctx.Err() + case <-t.C: + return nil + } +} + +func (p intervalPoller) Cancel() { + p.cancelFunc() +} diff --git a/oauth_device.go b/oauth_device.go index 139216a..d4615ef 100644 --- a/oauth_device.go +++ b/oauth_device.go @@ -2,6 +2,7 @@ package oauth import ( "bufio" + "context" "fmt" "io" "net/http" @@ -58,7 +59,7 @@ func (oa *Flow) DeviceFlow() (*api.AccessToken, error) { return nil, fmt.Errorf("error opening the web browser: %w", err) } - return device.PollTokenWithOptions(httpClient, host.TokenURL, device.PollOptions{ + return device.Wait(context.TODO(), httpClient, host.TokenURL, device.WaitOptions{ ClientID: oa.ClientID, DeviceCode: code, }) diff --git a/oauth_webapp.go b/oauth_webapp.go index 5484aee..e448715 100644 --- a/oauth_webapp.go +++ b/oauth_webapp.go @@ -1,6 +1,7 @@ package oauth import ( + "context" "fmt" "net/http" @@ -52,5 +53,7 @@ func (oa *Flow) WebAppFlow() (*api.AccessToken, error) { httpClient = http.DefaultClient } - return flow.AccessToken(httpClient, host.TokenURL, oa.ClientSecret) + return flow.Wait(context.TODO(), httpClient, host.TokenURL, webapp.WaitOptions{ + ClientSecret: oa.ClientSecret, + }) } diff --git a/webapp/examples_test.go b/webapp/examples_test.go index 2793cae..9e6ad91 100644 --- a/webapp/examples_test.go +++ b/webapp/examples_test.go @@ -1,6 +1,7 @@ package webapp import ( + "context" "fmt" "net/http" "os" @@ -42,7 +43,9 @@ func Example() { } httpClient := http.DefaultClient - accessToken, err := flow.AccessToken(httpClient, "https://github.com/login/oauth/access_token", clientSecret) + accessToken, err := flow.Wait(context.TODO(), httpClient, "https://github.com/login/oauth/access_token", WaitOptions{ + ClientSecret: clientSecret, + }) if err != nil { panic(err) } diff --git a/webapp/local_server.go b/webapp/local_server.go index d978d65..c895dfe 100644 --- a/webapp/local_server.go +++ b/webapp/local_server.go @@ -1,6 +1,7 @@ package webapp import ( + "context" "fmt" "io" "net" @@ -46,8 +47,13 @@ func (s *localServer) Serve() error { return http.Serve(s.listener, s) } -func (s *localServer) WaitForCode() (CodeResponse, error) { - return <-s.resultChan, nil +func (s *localServer) WaitForCode(ctx context.Context) (CodeResponse, error) { + select { + case <-ctx.Done(): + return CodeResponse{}, ctx.Err() + case code := <-s.resultChan: + return code, nil + } } // ServeHTTP implements http.Handler. diff --git a/webapp/webapp_flow.go b/webapp/webapp_flow.go index 9e86df0..49c6551 100644 --- a/webapp/webapp_flow.go +++ b/webapp/webapp_flow.go @@ -3,6 +3,7 @@ package webapp import ( + "context" "crypto/rand" "encoding/hex" "errors" @@ -85,8 +86,21 @@ func (flow *Flow) StartServer(writeSuccess func(io.Writer)) error { } // AccessToken blocks until the browser flow has completed and returns the access token. +// +// Deprecated: use Wait. func (flow *Flow) AccessToken(c httpClient, tokenURL, clientSecret string) (*api.AccessToken, error) { - code, err := flow.server.WaitForCode() + return flow.Wait(context.Background(), c, tokenURL, WaitOptions{ClientSecret: clientSecret}) +} + +// WaitOptions specifies parameters to exchange the access token for. +type WaitOptions struct { + // ClientSecret is the app client secret value. + ClientSecret string +} + +// Wait blocks until the browser flow has completed and returns the access token. +func (flow *Flow) Wait(ctx context.Context, c httpClient, tokenURL string, opts WaitOptions) (*api.AccessToken, error) { + code, err := flow.server.WaitForCode(ctx) if err != nil { return nil, err } @@ -97,7 +111,7 @@ func (flow *Flow) AccessToken(c httpClient, tokenURL, clientSecret string) (*api resp, err := api.PostForm(c, tokenURL, url.Values{ "client_id": {flow.clientID}, - "client_secret": {clientSecret}, + "client_secret": {opts.ClientSecret}, "code": {code.Code}, "state": {flow.state}, }) diff --git a/webapp/webapp_flow_test.go b/webapp/webapp_flow_test.go index 8814b06..f4e4355 100644 --- a/webapp/webapp_flow_test.go +++ b/webapp/webapp_flow_test.go @@ -2,6 +2,7 @@ package webapp import ( "bytes" + "context" "io/ioutil" "net" "net/http" @@ -132,7 +133,7 @@ func TestFlow_AccessToken(t *testing.T) { } }() - token, err := flow.AccessToken(client, "https://github.com/access_token", "OAUTH-SEKRIT") + token, err := flow.Wait(context.Background(), client, "https://github.com/access_token", WaitOptions{ClientSecret: "OAUTH-SEKRIT"}) if err != nil { t.Fatalf("AccessToken() error: %v", err) }