diff --git a/agent/utils/utils.go b/agent/utils/utils.go index 3455eb5..0f42c43 100644 --- a/agent/utils/utils.go +++ b/agent/utils/utils.go @@ -176,6 +176,7 @@ func ReadRequest(client *http.Client, proxyHost, backendID, requestID string, ca // target to the inverting proxy. type ResponseForwarder struct { proxyWriter *io.PipeWriter + startedChan chan struct{} responseBodyWriter *io.PipeWriter // wroteHeader is set when WriteHeader is called. It's used to ensure a @@ -205,6 +206,7 @@ func NewResponseForwarder(client *http.Client, proxyHost, backendID, requestID s // 2. responseBodyReader, responseBodyWriter: This pipe corresponds to the response body // from the backend target. To this pipe, we stream each read from backend target. proxyReader, proxyWriter := io.Pipe() + startedChan := make(chan struct{}, 1) responseBodyReader, responseBodyWriter := io.Pipe() proxyURL := proxyHost + ResponsePath @@ -218,6 +220,18 @@ func NewResponseForwarder(client *http.Client, proxyHost, backendID, requestID s errChan := make(chan error, 100) go func() { + // Wait until the response body has started being written + // (for a non-empty response) or for the response to + // be closed (for an empty response) before triggering + // the proxy request round trip. + // + // This ensures that we do not fetch the bearer token + // for the auth header until the last possible moment. + // That, in turn. prevents a race condition where the + // token expires between the header being generated + // and the request being sent to the proxy. + <-startedChan + if _, err := client.Do(proxyReq); err != nil { errChan <- err } @@ -248,11 +262,19 @@ func NewResponseForwarder(client *http.Client, proxyHost, backendID, requestID s }, wroteHeader: false, proxyWriter: proxyWriter, + startedChan: startedChan, responseBodyWriter: responseBodyWriter, errors: errChan, }, nil } +func (rf *ResponseForwarder) notify() { + if rf.startedChan != nil { + rf.startedChan <- struct{}{} + rf.startedChan = nil + } +} + func (rf *ResponseForwarder) Header() http.Header { return rf.response.Header } @@ -263,6 +285,7 @@ func (rf *ResponseForwarder) Write(buf []byte) (int, error) { if !rf.wroteHeader { rf.WriteHeader(http.StatusOK) } + rf.notify() count, err := rf.responseBodyWriter.Write(buf) if err != nil { rf.errors <- err @@ -307,6 +330,7 @@ func (rf *ResponseForwarder) WriteHeader(code int) { } func (rf *ResponseForwarder) Close() error { + rf.notify() var errs []error if err := rf.responseBodyWriter.Close(); err != nil { errs = append(errs, err)