Skip to content

Commit

Permalink
Merge pull request #25 from martinohansen/martin/token-handler
Browse files Browse the repository at this point in the history
fix: deadlock in RoundTrip
  • Loading branch information
frieser authored Mar 7, 2024
2 parents 9afe5bb + 4dafb8e commit b9f2fcc
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 34 deletions.
91 changes: 69 additions & 22 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,50 +18,97 @@ type Client struct {
expiration time.Time
token *Token
m *sync.Mutex
stopChan chan struct{}
}

type refreshTokenTransport struct {
type Transport struct {
rt http.RoundTripper
cli *Client
}

func (t refreshTokenTransport) RoundTrip(req *http.Request) (*http.Response, error) {
var err error
func (c *Client) refreshTokenIfNeeded() error {
c.m.Lock()
defer c.m.Unlock()

if time.Now().Add(time.Minute).Before(c.expiration) {
return nil
} else {
// Refresh the token if its expiration is less than a minute away
newToken, err := c.refreshToken(c.token.Refresh)
if err != nil {
return err
}
c.token = newToken
c.expiration = time.Now().Add(time.Duration(newToken.RefreshExpires-60) * time.Second)
return nil
}
}

func (c *Client) StartTokenHandler() {
c.stopChan = make(chan struct{})

// Initialize the first token and start the token handler
token, err := c.newToken()
if err != nil {
panic("Failed to get initial token: " + err.Error())
}
c.token = token

go func() {
for {
timeToWait := time.Until(c.expiration) - time.Minute
if timeToWait < 0 {
// If the token is already expired, try to refresh immediately
timeToWait = 0
}

select {
case <-c.stopChan:
return
case <-time.After(timeToWait):
if err := c.refreshTokenIfNeeded(); err != nil {
// TODO(Martin): add retry logic
panic("Failed to refresh token: " + err.Error())
}
}
}
}()
}

func (c *Client) StopTokenHandler() {
close(c.stopChan)
}

func (t Transport) RoundTrip(req *http.Request) (*http.Response, error) {
req.URL.Scheme = "https"
req.URL.Host = baseUrl
req.URL.Path = strings.Join([]string{apiPath, req.URL.Path}, "/")

req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")

t.cli.m.Lock()

if t.cli.expiration.Before(time.Now()) {
t.cli.token, err = t.cli.refreshToken(t.cli.token.Refresh)

if err != nil {
return nil, err
}
t.cli.expiration = t.cli.expiration.Add(time.Duration(t.cli.token.RefreshExpires-60) * time.Second)
// Add the access token to the request if it exists
if t.cli.token != nil {
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.cli.token.Access))
}
t.cli.m.Unlock()
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.cli.token.Access))

return t.rt.RoundTrip(req)
}

// NewClient creates a new Nordigen client that handles token refreshes and adds
// the necessary headers, host, and path to all requests.
func NewClient(secretId, secretKey string) (*Client, error) {
var err error
c := &Client{c: &http.Client{Timeout: 60 * time.Second}, m: &sync.Mutex{},
secretId: secretId,
secretKey: secretKey,
}

c := &Client{c: &http.Client{Timeout: 60 * time.Second}, m: &sync.Mutex{}}
c.token, err = c.newToken(secretId, secretKey)
// Add transport to handle headers, host and path for all requests
c.c.Transport = Transport{rt: http.DefaultTransport, cli: c}

if err != nil {
return nil, err
}
c.c.Transport = refreshTokenTransport{rt: http.DefaultTransport, cli: c}
c.expiration = time.Now().Add(time.Duration(c.token.AccessExpires-60) * time.Second)
// Start token handler
c.StartTokenHandler()
defer c.StopTokenHandler()

return c, nil
}
30 changes: 30 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package nordigen

import (
"os"
"testing"
"time"
)

// TestClientTokenRefresh should do a successful token refresh. We force this by
// setting the expiration to a time in the past and then calling any method.
// This test will only run if you have a valid secretId and secretKey in your
// environment.
func TestClientTokenRefresh(t *testing.T) {
id, id_exists := os.LookupEnv("NORDIGEN_SECRET_ID")
key, key_exists := os.LookupEnv("NORDIGEN_SECRET_KEY")
if !id_exists || !key_exists {
t.Skip("NORDIGEN_SECRET_ID and NORDIGEN_SECRET_KEY not set")
}

c, err := NewClient(id, key)
if err != nil {
t.Fatalf("NewClient: %s", err)
}

c.expiration = time.Now().Add(-time.Hour)
_, err = c.ListRequisitions()
if err != nil {
t.Fatalf("ListRequisitions: %s", err)
}
}
18 changes: 6 additions & 12 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"encoding/json"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
Expand All @@ -26,22 +25,17 @@ const tokenPath = "token"
const tokenNewPath = "new/"
const tokenRefreshPath = "refresh"

func (c Client) newToken(secretId, secretKey string) (*Token, error) {
func (c Client) newToken() (*Token, error) {
req := http.Request{
Method: http.MethodPost,
URL: &url.URL{
Scheme: "https",
Host: baseUrl,
Path: strings.Join([]string{apiPath, tokenPath, tokenNewPath}, "/"),
Path: strings.Join([]string{tokenPath, tokenNewPath}, "/"),
},
}
req.Header = http.Header{}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")

data, err := json.Marshal(Secret{
SecretId: secretId,
AccessId: secretKey,
SecretId: c.secretId,
AccessId: c.secretKey,
})
if err != nil {
return nil, err
Expand All @@ -52,7 +46,7 @@ func (c Client) newToken(secretId, secretKey string) (*Token, error) {
if err != nil {
return nil, err
}
body, err := ioutil.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)

if err != nil {
return nil, err
Expand Down Expand Up @@ -89,7 +83,7 @@ func (c Client) refreshToken(refresh string) (*Token, error) {
if err != nil {
return nil, err
}
body, err := ioutil.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)

if err != nil {
return nil, err
Expand Down

0 comments on commit b9f2fcc

Please sign in to comment.