Skip to content

Commit

Permalink
apiclient: split auth_key, auth_retry, auth_jwt
Browse files Browse the repository at this point in the history
  • Loading branch information
mmetc committed Jan 16, 2024
1 parent 24b5e8f commit aed4593
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 161 deletions.
161 changes: 0 additions & 161 deletions pkg/apiclient/auth.go → pkg/apiclient/auth_jwt.go
Expand Up @@ -3,10 +3,8 @@ package apiclient
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"net/http"
"net/http/httputil"
"net/url"
Expand All @@ -16,143 +14,9 @@ import (
"github.com/go-openapi/strfmt"
log "github.com/sirupsen/logrus"

"github.com/crowdsecurity/crowdsec/pkg/fflag"
"github.com/crowdsecurity/crowdsec/pkg/models"
)

type APIKeyTransport struct {
APIKey string
// Transport is the underlying HTTP transport to use when making requests.
// It will default to http.DefaultTransport if nil.
Transport http.RoundTripper
URL *url.URL
VersionPrefix string
UserAgent string
}

// RoundTrip implements the RoundTripper interface.
func (t *APIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if t.APIKey == "" {
return nil, errors.New("APIKey is empty")
}

// We must make a copy of the Request so
// that we don't modify the Request we were given. This is required by the
// specification of http.RoundTripper.
req = cloneRequest(req)
req.Header.Add("X-Api-Key", t.APIKey)

if t.UserAgent != "" {
req.Header.Add("User-Agent", t.UserAgent)
}

log.Debugf("req-api: %s %s", req.Method, req.URL.String())

if log.GetLevel() >= log.TraceLevel {
dump, _ := httputil.DumpRequest(req, true)
log.Tracef("auth-api request: %s", string(dump))
}

// Make the HTTP request.
resp, err := t.transport().RoundTrip(req)
if err != nil {
log.Errorf("auth-api: auth with api key failed return nil response, error: %s", err)

return resp, err
}

if log.GetLevel() >= log.TraceLevel {
dump, _ := httputil.DumpResponse(resp, true)
log.Tracef("auth-api response: %s", string(dump))
}

log.Debugf("resp-api: http %d", resp.StatusCode)

return resp, err
}

func (t *APIKeyTransport) Client() *http.Client {
return &http.Client{Transport: t}
}

func (t *APIKeyTransport) transport() http.RoundTripper {
if t.Transport != nil {
return t.Transport
}

return http.DefaultTransport
}

type retryRoundTripper struct {
next http.RoundTripper
maxAttempts int
retryStatusCodes []int
withBackOff bool
onBeforeRequest func(attempt int)
}

func (r retryRoundTripper) ShouldRetry(statusCode int) bool {
for _, code := range r.retryStatusCodes {
if code == statusCode {
return true
}
}

return false
}

func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
var (
resp *http.Response
err error
)

backoff := 0
maxAttempts := r.maxAttempts

if fflag.DisableHttpRetryBackoff.IsEnabled() {
maxAttempts = 1
}

for i := 0; i < maxAttempts; i++ {
if i > 0 {
if r.withBackOff {
//nolint:gosec
backoff += 10 + rand.Intn(20)
}

log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, i+1, r.maxAttempts)

select {
case <-req.Context().Done():
return nil, req.Context().Err()
case <-time.After(time.Duration(backoff) * time.Second):
}
}

if r.onBeforeRequest != nil {
r.onBeforeRequest(i)
}

clonedReq := cloneRequest(req)

resp, err = r.next.RoundTrip(clonedReq)
if err != nil {
if left := maxAttempts - i - 1; left > 0 {
log.Errorf("error while performing request: %s; %d retries left", err, left)
}

continue
}

if !r.ShouldRetry(resp.StatusCode) {
return resp, nil
}
}

return resp, err
}

type JWTTransport struct {
MachineID *string
Password *strfmt.Password
Expand Down Expand Up @@ -351,28 +215,3 @@ func (t *JWTTransport) transport() http.RoundTripper {
},
}
}

// cloneRequest returns a clone of the provided *http.Request. The clone is a
// shallow copy of the struct and its Header map.
func cloneRequest(r *http.Request) *http.Request {
// shallow copy of the struct
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header, len(r.Header))

for k, s := range r.Header {
r2.Header[k] = append([]string(nil), s...)
}

if r.Body != nil {
var b bytes.Buffer

b.ReadFrom(r.Body)

r.Body = io.NopCloser(&b)
r2.Body = io.NopCloser(bytes.NewReader(b.Bytes()))
}

return r2
}
73 changes: 73 additions & 0 deletions pkg/apiclient/auth_key.go
@@ -0,0 +1,73 @@
package apiclient

import (
"errors"
"net/http"
"net/http/httputil"
"net/url"

log "github.com/sirupsen/logrus"
)

type APIKeyTransport struct {
APIKey string
// Transport is the underlying HTTP transport to use when making requests.
// It will default to http.DefaultTransport if nil.
Transport http.RoundTripper
URL *url.URL
VersionPrefix string
UserAgent string
}

// RoundTrip implements the RoundTripper interface.
func (t *APIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if t.APIKey == "" {
return nil, errors.New("APIKey is empty")
}

Check warning on line 26 in pkg/apiclient/auth_key.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiclient/auth_key.go#L23-L26

Added lines #L23 - L26 were not covered by tests

// We must make a copy of the Request so
// that we don't modify the Request we were given. This is required by the
// specification of http.RoundTripper.
req = cloneRequest(req)
req.Header.Add("X-Api-Key", t.APIKey)

if t.UserAgent != "" {
req.Header.Add("User-Agent", t.UserAgent)
}

Check warning on line 36 in pkg/apiclient/auth_key.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiclient/auth_key.go#L31-L36

Added lines #L31 - L36 were not covered by tests

log.Debugf("req-api: %s %s", req.Method, req.URL.String())

if log.GetLevel() >= log.TraceLevel {
dump, _ := httputil.DumpRequest(req, true)
log.Tracef("auth-api request: %s", string(dump))
}

Check warning on line 43 in pkg/apiclient/auth_key.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiclient/auth_key.go#L38-L43

Added lines #L38 - L43 were not covered by tests

// Make the HTTP request.
resp, err := t.transport().RoundTrip(req)
if err != nil {
log.Errorf("auth-api: auth with api key failed return nil response, error: %s", err)

return resp, err
}

Check warning on line 51 in pkg/apiclient/auth_key.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiclient/auth_key.go#L46-L51

Added lines #L46 - L51 were not covered by tests

if log.GetLevel() >= log.TraceLevel {
dump, _ := httputil.DumpResponse(resp, true)
log.Tracef("auth-api response: %s", string(dump))
}

Check warning on line 56 in pkg/apiclient/auth_key.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiclient/auth_key.go#L53-L56

Added lines #L53 - L56 were not covered by tests

log.Debugf("resp-api: http %d", resp.StatusCode)

return resp, err

Check warning on line 60 in pkg/apiclient/auth_key.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiclient/auth_key.go#L58-L60

Added lines #L58 - L60 were not covered by tests
}

func (t *APIKeyTransport) Client() *http.Client {
return &http.Client{Transport: t}

Check warning on line 64 in pkg/apiclient/auth_key.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiclient/auth_key.go#L63-L64

Added lines #L63 - L64 were not covered by tests
}

func (t *APIKeyTransport) transport() http.RoundTripper {
if t.Transport != nil {
return t.Transport
}

Check warning on line 70 in pkg/apiclient/auth_key.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiclient/auth_key.go#L67-L70

Added lines #L67 - L70 were not covered by tests

return http.DefaultTransport

Check warning on line 72 in pkg/apiclient/auth_key.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiclient/auth_key.go#L72

Added line #L72 was not covered by tests
}
File renamed without changes.
81 changes: 81 additions & 0 deletions pkg/apiclient/auth_retry.go
@@ -0,0 +1,81 @@
package apiclient

import (
"math/rand"
"net/http"
"time"

log "github.com/sirupsen/logrus"

"github.com/crowdsecurity/crowdsec/pkg/fflag"
)

type retryRoundTripper struct {
next http.RoundTripper
maxAttempts int
retryStatusCodes []int
withBackOff bool
onBeforeRequest func(attempt int)
}

func (r retryRoundTripper) ShouldRetry(statusCode int) bool {
for _, code := range r.retryStatusCodes {
if code == statusCode {
return true
}

Check warning on line 25 in pkg/apiclient/auth_retry.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiclient/auth_retry.go#L24-L25

Added lines #L24 - L25 were not covered by tests
}

return false
}

func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
var (
resp *http.Response
err error
)

backoff := 0
maxAttempts := r.maxAttempts

if fflag.DisableHttpRetryBackoff.IsEnabled() {
maxAttempts = 1
}

for i := 0; i < maxAttempts; i++ {
if i > 0 {
if r.withBackOff {
//nolint:gosec
backoff += 10 + rand.Intn(20)
}

Check warning on line 49 in pkg/apiclient/auth_retry.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiclient/auth_retry.go#L46-L49

Added lines #L46 - L49 were not covered by tests

log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, i+1, r.maxAttempts)

select {
case <-req.Context().Done():
return nil, req.Context().Err()
case <-time.After(time.Duration(backoff) * time.Second):

Check warning on line 56 in pkg/apiclient/auth_retry.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiclient/auth_retry.go#L51-L56

Added lines #L51 - L56 were not covered by tests
}
}

if r.onBeforeRequest != nil {
r.onBeforeRequest(i)
}

clonedReq := cloneRequest(req)

resp, err = r.next.RoundTrip(clonedReq)
if err != nil {
if left := maxAttempts - i - 1; left > 0 {
log.Errorf("error while performing request: %s; %d retries left", err, left)
}

Check warning on line 70 in pkg/apiclient/auth_retry.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiclient/auth_retry.go#L69-L70

Added lines #L69 - L70 were not covered by tests

continue
}

if !r.ShouldRetry(resp.StatusCode) {
return resp, nil
}
}

return resp, err
}
32 changes: 32 additions & 0 deletions pkg/apiclient/clone.go
@@ -0,0 +1,32 @@
package apiclient

import (
"bytes"
"io"
"net/http"
)

// cloneRequest returns a clone of the provided *http.Request. The clone is a
// shallow copy of the struct and its Header map.
func cloneRequest(r *http.Request) *http.Request {
// shallow copy of the struct
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header, len(r.Header))

for k, s := range r.Header {
r2.Header[k] = append([]string(nil), s...)
}

if r.Body != nil {
var b bytes.Buffer

b.ReadFrom(r.Body)

r.Body = io.NopCloser(&b)
r2.Body = io.NopCloser(bytes.NewReader(b.Bytes()))
}

return r2
}

0 comments on commit aed4593

Please sign in to comment.