Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions agent/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ func ReadRequest(client *http.Client, proxyHost, backendID, requestID string, ca
// target to the inverting proxy.
type ResponseForwarder struct {
proxyWriter *io.PipeWriter
startedChan chan struct{}
responseBodyWriter *io.PipeWriter

// wroteHeader is set when WriteHeader is called. It's used to ensure a
Expand Down Expand Up @@ -205,6 +206,7 @@ func NewResponseForwarder(client *http.Client, proxyHost, backendID, requestID s
// 2. responseBodyReader, responseBodyWriter: This pipe corresponds to the response body
// from the backend target. To this pipe, we stream each read from backend target.
proxyReader, proxyWriter := io.Pipe()
startedChan := make(chan struct{}, 1)
responseBodyReader, responseBodyWriter := io.Pipe()

proxyURL := proxyHost + ResponsePath
Expand All @@ -218,6 +220,18 @@ func NewResponseForwarder(client *http.Client, proxyHost, backendID, requestID s

errChan := make(chan error, 100)
go func() {
// Wait until the response body has started being written
// (for a non-empty response) or for the response to
// be closed (for an empty response) before triggering
// the proxy request round trip.
//
// This ensures that we do not fetch the bearer token
// for the auth header until the last possible moment.
// That, in turn. prevents a race condition where the
// token expires between the header being generated
// and the request being sent to the proxy.
<-startedChan

if _, err := client.Do(proxyReq); err != nil {
errChan <- err
}
Expand Down Expand Up @@ -248,11 +262,19 @@ func NewResponseForwarder(client *http.Client, proxyHost, backendID, requestID s
},
wroteHeader: false,
proxyWriter: proxyWriter,
startedChan: startedChan,
responseBodyWriter: responseBodyWriter,
errors: errChan,
}, nil
}

func (rf *ResponseForwarder) notify() {
if rf.startedChan != nil {
rf.startedChan <- struct{}{}
rf.startedChan = nil
}
}

func (rf *ResponseForwarder) Header() http.Header {
return rf.response.Header
}
Expand All @@ -263,6 +285,7 @@ func (rf *ResponseForwarder) Write(buf []byte) (int, error) {
if !rf.wroteHeader {
rf.WriteHeader(http.StatusOK)
}
rf.notify()
count, err := rf.responseBodyWriter.Write(buf)
if err != nil {
rf.errors <- err
Expand Down Expand Up @@ -307,6 +330,7 @@ func (rf *ResponseForwarder) WriteHeader(code int) {
}

func (rf *ResponseForwarder) Close() error {
rf.notify()
var errs []error
if err := rf.responseBodyWriter.Close(); err != nil {
errs = append(errs, err)
Expand Down