Skip to content

Commit

Permalink
Introduce Go context-aware Wait functions for blocking operation (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
mislav committed Dec 17, 2022
1 parent 6f7124e commit 2bcde89
Show file tree
Hide file tree
Showing 10 changed files with 133 additions and 71 deletions.
42 changes: 19 additions & 23 deletions device/device_flow.go
Expand Up @@ -13,6 +13,7 @@
package device

import (
"context"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -103,16 +104,16 @@ const defaultGrantType = "urn:ietf:params:oauth:grant-type:device_code"

// PollToken polls the server at pollURL until an access token is granted or denied.
//
// Deprecated: use PollTokenWithOptions.
// Deprecated: use Wait.
func PollToken(c httpClient, pollURL string, clientID string, code *CodeResponse) (*api.AccessToken, error) {
return PollTokenWithOptions(c, pollURL, PollOptions{
return Wait(context.Background(), c, pollURL, WaitOptions{
ClientID: clientID,
DeviceCode: code,
})
}

// PollOptions specifies parameters to poll the server with until authentication completes.
type PollOptions struct {
// WaitOptions specifies parameters to poll the server with until authentication completes.
type WaitOptions struct {
// ClientID is the app client ID value.
ClientID string
// ClientSecret is the app client secret value. Optional: only pass if the server requires it.
Expand All @@ -122,30 +123,28 @@ type PollOptions struct {
// GrantType overrides the default value specified by OAuth 2.0 Device Code. Optional.
GrantType string

timeNow func() time.Time
timeSleep func(time.Duration)
newPoller pollerFactory
}

// PollTokenWithOptions polls the server at uri until authorization completes.
func PollTokenWithOptions(c httpClient, uri string, opts PollOptions) (*api.AccessToken, error) {
timeNow := opts.timeNow
if timeNow == nil {
timeNow = time.Now
}
timeSleep := opts.timeSleep
if timeSleep == nil {
timeSleep = time.Sleep
}

// Wait polls the server at uri until authorization completes.
func Wait(ctx context.Context, c httpClient, uri string, opts WaitOptions) (*api.AccessToken, error) {
checkInterval := time.Duration(opts.DeviceCode.Interval) * time.Second
expiresAt := timeNow().Add(time.Duration(opts.DeviceCode.ExpiresIn) * time.Second)
expiresIn := time.Duration(opts.DeviceCode.ExpiresIn) * time.Second
grantType := opts.GrantType
if opts.GrantType == "" {
grantType = defaultGrantType
}

makePoller := opts.newPoller
if makePoller == nil {
makePoller = newPoller
}
_, poll := makePoller(ctx, checkInterval, expiresIn)

for {
timeSleep(checkInterval)
if err := poll.Wait(); err != nil {
return nil, err
}

values := url.Values{
"client_id": {opts.ClientID},
Expand All @@ -158,6 +157,7 @@ func PollTokenWithOptions(c httpClient, uri string, opts PollOptions) (*api.Acce
values.Add("client_secret", opts.ClientSecret)
}

// TODO: pass tctx down to the HTTP layer
resp, err := api.PostForm(c, uri, values)
if err != nil {
return nil, err
Expand All @@ -170,9 +170,5 @@ func PollTokenWithOptions(c httpClient, uri string, opts PollOptions) (*api.Acce
} else if !(errors.As(err, &apiError) && apiError.Code == "authorization_pending") {
return nil, err
}

if timeNow().After(expiresAt) {
return nil, ErrTimeout
}
}
}
72 changes: 33 additions & 39 deletions device/device_flow_test.go
Expand Up @@ -2,6 +2,8 @@ package device

import (
"bytes"
"context"
"errors"
"io/ioutil"
"net/http"
"net/url"
Expand Down Expand Up @@ -230,28 +232,16 @@ func TestRequestCode(t *testing.T) {
}

func TestPollToken(t *testing.T) {
var totalSlept time.Duration
mockSleep := func(d time.Duration) {
totalSlept += d
}
duration := func(d string) time.Duration {
res, _ := time.ParseDuration(d)
return res
}
clock := func(durations ...string) func() time.Time {
count := 0
now := time.Now()
return func() time.Time {
t := now.Add(duration(durations[count]))
count++
return t
makeFakePoller := func(maxWaits int) pollerFactory {
return func(ctx context.Context, interval, expiresIn time.Duration) (context.Context, poller) {
return ctx, &fakePoller{maxWaits: maxWaits}
}
}

type args struct {
http apiClient
url string
opts PollOptions
opts WaitOptions
}
tests := []struct {
name string
Expand Down Expand Up @@ -279,7 +269,7 @@ func TestPollToken(t *testing.T) {
},
},
url: "https://github.com/oauth",
opts: PollOptions{
opts: WaitOptions{
ClientID: "CLIENT-ID",
DeviceCode: &CodeResponse{
DeviceCode: "DEVIC",
Expand All @@ -288,14 +278,12 @@ func TestPollToken(t *testing.T) {
ExpiresIn: 99,
Interval: 5,
},
timeSleep: mockSleep,
timeNow: clock("0", "5s", "10s"),
newPoller: makeFakePoller(2),
},
},
want: &api.AccessToken{
Token: "123abc",
},
slept: duration("10s"),
posts: []postArgs{
{
url: "https://github.com/oauth",
Expand Down Expand Up @@ -328,7 +316,7 @@ func TestPollToken(t *testing.T) {
},
},
url: "https://github.com/oauth",
opts: PollOptions{
opts: WaitOptions{
ClientID: "CLIENT-ID",
ClientSecret: "SEKRIT",
GrantType: "device_code",
Expand All @@ -339,14 +327,12 @@ func TestPollToken(t *testing.T) {
ExpiresIn: 99,
Interval: 5,
},
timeSleep: mockSleep,
timeNow: clock("0", "5s", "10s"),
newPoller: makeFakePoller(1),
},
},
want: &api.AccessToken{
Token: "123abc",
},
slept: duration("5s"),
posts: []postArgs{
{
url: "https://github.com/oauth",
Expand Down Expand Up @@ -377,21 +363,19 @@ func TestPollToken(t *testing.T) {
},
},
url: "https://github.com/oauth",
opts: PollOptions{
opts: WaitOptions{
ClientID: "CLIENT-ID",
DeviceCode: &CodeResponse{
DeviceCode: "DEVIC",
UserCode: "123-abc",
VerificationURI: "http://verify.me",
ExpiresIn: 99,
ExpiresIn: 14,
Interval: 5,
},
timeSleep: mockSleep,
timeNow: clock("0", "5s", "15m"),
newPoller: makeFakePoller(2),
},
},
wantErr: "authentication timed out",
slept: duration("10s"),
wantErr: "context deadline exceeded",
posts: []postArgs{
{
url: "https://github.com/oauth",
Expand Down Expand Up @@ -424,7 +408,7 @@ func TestPollToken(t *testing.T) {
},
},
url: "https://github.com/oauth",
opts: PollOptions{
opts: WaitOptions{
ClientID: "CLIENT-ID",
DeviceCode: &CodeResponse{
DeviceCode: "DEVIC",
Expand All @@ -433,12 +417,10 @@ func TestPollToken(t *testing.T) {
ExpiresIn: 99,
Interval: 5,
},
timeSleep: mockSleep,
timeNow: clock("0", "5s"),
newPoller: makeFakePoller(1),
},
},
wantErr: "access_denied",
slept: duration("5s"),
posts: []postArgs{
{
url: "https://github.com/oauth",
Expand All @@ -453,8 +435,7 @@ func TestPollToken(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
totalSlept = 0
got, err := PollTokenWithOptions(&tt.args.http, tt.args.url, tt.args.opts)
got, err := Wait(context.Background(), &tt.args.http, tt.args.url, tt.args.opts)
if (err != nil) != (tt.wantErr != "") {
t.Errorf("PollToken() error = %v, wantErr %v", err, tt.wantErr)
return
Expand All @@ -468,9 +449,22 @@ func TestPollToken(t *testing.T) {
if !reflect.DeepEqual(tt.args.http.calls, tt.posts) {
t.Errorf("PostForm() = %v, want %v", tt.args.http.calls, tt.posts)
}
if totalSlept != tt.slept {
t.Errorf("slept %v, wanted %v", totalSlept, tt.slept)
}
})
}
}

type fakePoller struct {
maxWaits int
count int
}

func (p *fakePoller) Wait() error {
if p.count == p.maxWaits {
return errors.New("context deadline exceeded")
}
p.count++
return nil
}

func (p *fakePoller) Cancel() {
}
3 changes: 2 additions & 1 deletion device/examples_test.go
@@ -1,6 +1,7 @@
package device

import (
"context"
"fmt"
"net/http"
"os"
Expand All @@ -22,7 +23,7 @@ func Example() {
fmt.Printf("Copy code: %s\n", code.UserCode)
fmt.Printf("then open: %s\n", code.VerificationURI)

accessToken, err := PollTokenWithOptions(httpClient, "https://github.com/login/oauth/access_token", PollOptions{
accessToken, err := Wait(context.TODO(), httpClient, "https://github.com/login/oauth/access_token", WaitOptions{
ClientID: clientID,
DeviceCode: code,
})
Expand Down
43 changes: 43 additions & 0 deletions device/poller.go
@@ -0,0 +1,43 @@
package device

import (
"context"
"time"
)

type poller interface {
Wait() error
Cancel()
}

type pollerFactory func(context.Context, time.Duration, time.Duration) (context.Context, poller)

func newPoller(ctx context.Context, checkInteval, expiresIn time.Duration) (context.Context, poller) {
c, cancel := context.WithTimeout(ctx, expiresIn)
return c, &intervalPoller{
ctx: c,
interval: checkInteval,
cancelFunc: cancel,
}
}

type intervalPoller struct {
ctx context.Context
interval time.Duration
cancelFunc func()
}

func (p intervalPoller) Wait() error {
t := time.NewTimer(p.interval)
select {
case <-p.ctx.Done():
t.Stop()
return p.ctx.Err()
case <-t.C:
return nil
}
}

func (p intervalPoller) Cancel() {
p.cancelFunc()
}
3 changes: 2 additions & 1 deletion oauth_device.go
Expand Up @@ -2,6 +2,7 @@ package oauth

import (
"bufio"
"context"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -58,7 +59,7 @@ func (oa *Flow) DeviceFlow() (*api.AccessToken, error) {
return nil, fmt.Errorf("error opening the web browser: %w", err)
}

return device.PollTokenWithOptions(httpClient, host.TokenURL, device.PollOptions{
return device.Wait(context.TODO(), httpClient, host.TokenURL, device.WaitOptions{
ClientID: oa.ClientID,
DeviceCode: code,
})
Expand Down
5 changes: 4 additions & 1 deletion oauth_webapp.go
@@ -1,6 +1,7 @@
package oauth

import (
"context"
"fmt"
"net/http"

Expand Down Expand Up @@ -52,5 +53,7 @@ func (oa *Flow) WebAppFlow() (*api.AccessToken, error) {
httpClient = http.DefaultClient
}

return flow.AccessToken(httpClient, host.TokenURL, oa.ClientSecret)
return flow.Wait(context.TODO(), httpClient, host.TokenURL, webapp.WaitOptions{
ClientSecret: oa.ClientSecret,
})
}
5 changes: 4 additions & 1 deletion webapp/examples_test.go
@@ -1,6 +1,7 @@
package webapp

import (
"context"
"fmt"
"net/http"
"os"
Expand Down Expand Up @@ -42,7 +43,9 @@ func Example() {
}

httpClient := http.DefaultClient
accessToken, err := flow.AccessToken(httpClient, "https://github.com/login/oauth/access_token", clientSecret)
accessToken, err := flow.Wait(context.TODO(), httpClient, "https://github.com/login/oauth/access_token", WaitOptions{
ClientSecret: clientSecret,
})
if err != nil {
panic(err)
}
Expand Down

0 comments on commit 2bcde89

Please sign in to comment.