Skip to content

Commit

Permalink
make request for IP restricted backup (from AWS Lambda) in parallel w…
Browse files Browse the repository at this point in the history
…ith s3 backup request
  • Loading branch information
danenania committed Sep 14, 2018
1 parent 8c6cc7b commit bada015
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 55 deletions.
147 changes: 106 additions & 41 deletions fetch/fetch.go
@@ -1,11 +1,13 @@
package fetch

import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
Expand All @@ -15,11 +17,9 @@ import (
"time"

"github.com/certifi/gocertifi"
"github.com/davecgh/go-spew/spew"
"github.com/envkey/envkey-fetch/cache"
"github.com/envkey/envkey-fetch/parser"
"github.com/envkey/envkey-fetch/version"
"github.com/envkey/myhttp"
multierror "github.com/hashicorp/go-multierror"
)

Expand All @@ -36,15 +36,27 @@ var DefaultHost = "env.envkey.com"
var BackupHost = "s3-eu-west-1.amazonaws.com/envkey-backup/envs"
var BackupHostRestricted = "me66hg5t17.execute-api.eu-west-1.amazonaws.com/default/envBackup"
var ApiVersion = 1
var HttpGetter = myhttp.New(time.Second * 6)

var Client *http.Client

type httpChannelResponse struct {
response *http.Response
url string
}

type httpChannelErr struct {
err error
url string
}

func Fetch(envkey string, options FetchOptions) (string, error) {
if len(strings.Split(envkey, "-")) < 2 {
return "", errors.New("ENVKEY invalid")
}

if options.TimeoutSeconds != 6.0 {
HttpGetter = myhttp.New(time.Second * time.Duration(options.TimeoutSeconds))
// may be initalized already when mocking for tests
if Client == nil {
InitHttpClient(options.TimeoutSeconds)
}

var fetchCache *cache.Cache
Expand Down Expand Up @@ -116,25 +128,78 @@ func UrlWithLoggingParams(baseUrl string, options FetchOptions) string {
)
}

func httpGet(url string) (*http.Response, error) {
fmt.Println("httpGet:")
fmt.Println(url)
func InitHttpClient(timeoutSeconds float64) {
// http.Client.Get reuses the transport. this should be created once.
tp := http.Transport{}
to := time.Second * time.Duration(timeoutSeconds)

r, err := HttpGetter.Get(url)
tp.DialContext = (&net.Dialer{
Timeout: to,
}).DialContext

// if error caused by missing root certificates, pull in gocertifi certs (which come from Mozilla) and try again with those
if err != nil && strings.Contains(err.Error(), "x509: failed to load system roots") {
certPool, err := gocertifi.CACerts()
if err != nil {
return nil, err
tp.TLSHandshakeTimeout = to
tp.ResponseHeaderTimeout = to
tp.ExpectContinueTimeout = to

Client = &http.Client{
Transport: &tp,
}
}

func httpExecRequest(
req *http.Request,
respChan chan httpChannelResponse,
errChan chan httpChannelErr,
) {
resp, err := Client.Do(req)
if err == nil {
respChan <- httpChannelResponse{resp, req.URL.String()}
} else {
// if error caused by missing root certificates, pull in gocertifi certs (which come from Mozilla) and try again with those
if strings.Contains(err.Error(), "x509: failed to load system roots") {
certPool, certPoolErr := gocertifi.CACerts()
if certPoolErr != nil {
errChan <- httpChannelErr{multierror.Append(err, certPoolErr), req.URL.String()}
return
}
Client.Transport.(*http.Transport).TLSClientConfig = &tls.Config{RootCAs: certPool}
httpExecRequest(req, respChan, errChan)
} else {
errChan <- httpChannelErr{err, req.URL.String()}
}
transport := &http.Transport{
TLSClientConfig: &tls.Config{RootCAs: certPool},
}
}

func httpGetAsync(
url string,
ctx context.Context,
respChan chan httpChannelResponse,
errChan chan httpChannelErr,
) {
req, err := http.NewRequest("GET", url, nil)

if err != nil {
errChan <- httpChannelErr{err, url}
return
}

req = req.WithContext(ctx)

go httpExecRequest(req, respChan, errChan)
}

func httpGet(url string) (*http.Response, error) {
respChan, errChan := make(chan httpChannelResponse), make(chan httpChannelErr)

httpGetAsync(url, context.Background(), respChan, errChan)

for {
select {
case channelResp := <-respChan:
return channelResp.response, nil
case channelErr := <-errChan:
return nil, channelErr.err
}
HttpGetter.Client.Transport = transport
return HttpGetter.Get(url)
} else {
return r, err
}
}

Expand Down Expand Up @@ -215,30 +280,36 @@ func fetchBackup(envkeyParam string, options FetchOptions) (*http.Response, erro
fmt.Fprintf(os.Stderr, "Attempting to load encrypted config from backup urls: %s\n", backupUrls)
}

success, failed := make(chan *http.Response), make(chan error)
respChan, errChan := make(chan httpChannelResponse), make(chan httpChannelErr)

cancelFnByUrl := map[string]context.CancelFunc{}

for _, backupUrl := range backupUrls {
go func(backupUrl string) {
urlWithParams := UrlWithLoggingParams(backupUrl, options)
r, err := httpGet(urlWithParams)
logRequestIfVerbose(urlWithParams, options, err, r)
if err == nil {
success <- r
} else {
failed <- err
}
}(backupUrl)
ctx, cancel := context.WithCancel(context.Background())
urlWithParams := UrlWithLoggingParams(backupUrl, options)
cancelFnByUrl[urlWithParams] = cancel
httpGetAsync(urlWithParams, ctx, respChan, errChan)
}

var err error

for {
r, any := <-success
channelResp, any := <-respChan

if any {
return r, nil
logRequestIfVerbose(channelResp.url, options, nil, channelResp.response)

// cancel other requests
for backupUrl, cancel := range cancelFnByUrl {
if backupUrl != channelResp.url {
cancel()
}
}

return channelResp.response, nil
} else {
err = multierror.Append(err, <-failed)
channelErr := <-errChan
logRequestIfVerbose(channelErr.url, options, channelErr.err, nil)
err = multierror.Append(err, channelErr.err)
}
}

Expand Down Expand Up @@ -271,12 +342,6 @@ func getJson(envkeyHost string, envkeyParam string, options FetchOptions, respon
if r != nil {
defer r.Body.Close()
}

fmt.Println("backupFetchErr")
spew.Dump(backupFetchErr)
fmt.Println("r.StatusCode")
spew.Dump(r.StatusCode)

}
}

Expand Down
24 changes: 10 additions & 14 deletions fetch/fetch_test/fetch_test.go
Expand Up @@ -74,21 +74,19 @@ var fetchTests = []struct {

func TestFetch(t *testing.T) {
assert := assert.New(t)
httpmock.ActivateNonDefault(fetch.HttpGetter.Client)
fetch.InitHttpClient(2.0)
httpmock.ActivateNonDefault(fetch.Client)
defer httpmock.DeactivateAndReset()

// Caching enabled
opts := fetch.FetchOptions{true, "", "envkey-fetch", version.Version, false, 2.0}

// Caching enabled
for _, test := range fetchTests {
var envkeyParam = strings.Split(test.envkey, "-")[0]

baseUrl := (test.protocol + "://" + test.host + "/v" + strconv.Itoa(fetch.ApiVersion) + "/" + envkeyParam)
opts := fetch.FetchOptions{true, "", "envkey-fetch", version.Version, false, 6.0}
url := fetch.UrlWithLoggingParams(baseUrl, opts)

fmt.Println("TestFetch url:")
fmt.Println(url)

httpmock.RegisterResponder(
"GET",
url,
Expand Down Expand Up @@ -119,7 +117,7 @@ func TestFetch(t *testing.T) {
assert.NotNil(err, "Should not cache the response.")
}

res, err = fetch.Fetch(test.envkey, fetch.FetchOptions{false, "", "envkey-fetch", version.Version, false, 6.0})
res, err = fetch.Fetch(test.envkey, fetch.FetchOptions{false, "", "envkey-fetch", version.Version, false, 2.0})

// With caching disabled
if test.expectErr {
Expand All @@ -144,31 +142,29 @@ func TestLiveFetch(t *testing.T) {
assert := assert.New(t)

// Test valid
validRes, err := fetch.Fetch(VALID_LIVE_ENVKEY, fetch.FetchOptions{false, "", "envkey-fetch", version.Version, false, 6.0})
validRes, err := fetch.Fetch(VALID_LIVE_ENVKEY, fetch.FetchOptions{false, "", "envkey-fetch", version.Version, false, 2.0})
assert.Nil(err)
assert.Equal("{\"TEST\":\"it\",\"TEST_2\":\"works!\",\"TEST_INJECTION\":\"'$(uname)\",\"TEST_SINGLE_QUOTES\":\"this' is ok\",\"TEST_SPACES\":\"it does work!\"}", validRes)

// Test invalid
invalidRes, err := fetch.Fetch(INVALID_LIVE_ENVKEY, fetch.FetchOptions{false, "", "envkey-fetch", version.Version, false, 6.0})
invalidRes, err := fetch.Fetch(INVALID_LIVE_ENVKEY, fetch.FetchOptions{false, "", "envkey-fetch", version.Version, false, 2.0})
assert.NotNil(err)
assert.Equal("ENVKEY invalid", string(err.Error()))
assert.Equal("", invalidRes)
}

func TestBackup(t *testing.T) {
assert := assert.New(t)
httpmock.ActivateNonDefault(fetch.HttpGetter.Client)
fetch.InitHttpClient(2.0)
httpmock.ActivateNonDefault(fetch.Client)
defer httpmock.DeactivateAndReset()

// Test with backup
fetch.DefaultHost = "localhost:61034"
opts := fetch.FetchOptions{false, "", "envkey-fetch", version.Version, false, 6.0}
opts := fetch.FetchOptions{false, "", "envkey-fetch", version.Version, false, 2.0}
url := fetch.UrlWithLoggingParams("https://"+fetch.BackupHost+"/v"+strconv.Itoa(fetch.ApiVersion)+"/validkey", opts)
restrictedUrl := fetch.UrlWithLoggingParams(fmt.Sprintf("%s?v=%s&id=%s", ("https://"+fetch.BackupHostRestricted), strconv.Itoa(fetch.ApiVersion), "validkey"), opts)

fmt.Println(url)
fmt.Println(restrictedUrl)

httpmock.RegisterResponder(
"GET",
url,
Expand Down

0 comments on commit bada015

Please sign in to comment.