Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions client/content_negotiation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ func runBuildHTTPCases(t *testing.T, cases iter.Seq[buildHTTPCase]) {

func runBuildHTTPCase(tc buildHTTPCase) func(*testing.T) {
return func(t *testing.T) {
ctx := t.Context()
method := tc.method
if method == "" {
method = http.MethodPost
Expand All @@ -110,7 +111,9 @@ func runBuildHTTPCase(tc buildHTTPCase) func(*testing.T) {

r := request.New(method, "/", writer)
r.SetConsumes(tc.consumes)
req, err := r.BuildHTTP(tc.mediaType, "/", producers, strfmt.Default, nil)
req, cancel, err := r.BuildHTTPContext(ctx, tc.mediaType, "/", producers, strfmt.Default, nil)
defer cancel()

if tc.wantErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.wantErr)
Expand Down Expand Up @@ -395,7 +398,7 @@ func payloadStructCases() iter.Seq[buildHTTPCase] {
//
// Cases with empty consumes exercise the buildHTTP-direct entry point
// (i.e. external callers of BuildHTTP that have already picked a mime
// without going through createHttpRequest).
// without going through createHTTPRequest).
func payloadReaderCases() iter.Seq[buildHTTPCase] {
return slices.Values([]buildHTTPCase{
{
Expand Down
14 changes: 0 additions & 14 deletions client/internal/request/consts_test.go

This file was deleted.

111 changes: 92 additions & 19 deletions client/internal/request/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ var _ runtime.ClientRequest = new(Request) // ensure compliance to the interface
//
// # Request binding
//
// The binding of parameters is carried out by method [request.BuildHTTP].
// The binding of parameters is carried out by method [Request.BuildHTTPContext].
//
// It analyzes parameters, which may come in different flavors:
//
Expand All @@ -52,8 +52,8 @@ var _ runtime.ClientRequest = new(Request) // ensure compliance to the interface
// - file, multipart form or io.Reader body: a streaming request with an attached go routine that consumes the [io.Reader].
// - buffered body: a simple request
//
// In all cases, it is left to the caller to set the request's [context.Context]: [request.BuildHTTP] only builds
// requests with [context.Background].
// The caller passes the parent [context.Context] to [Request.BuildHTTPContext] and receives back a cancel
// function to release the resources held by the derived request context once the response is consumed.
//
// # Authentication
//
Expand Down Expand Up @@ -289,7 +289,9 @@ func (r *Request) SetConsumes(consumers []string) {
r.consumes = consumers
}

// BuildHTTP dispatches to one of two end-to-end builders based on whether:
// BuildHTTPContext binds the request parameters and returns a ready-to-send [http.Request].
//
// Dispatch picks one of two end-to-end builders based on whether:
//
// - the body source is a stream (multipart pipe or stream payload)
// - or a buffer (urlencoded form, producer output, or no body)
Expand All @@ -299,17 +301,56 @@ func (r *Request) SetConsumes(consumers []string) {
//
// The split mirrors the auth question: streaming bodies require a lazy body-copy closure during AuthenticateRequest,
// whereas buffered bodies do not.
func (r *Request) BuildHTTP(mediaType, basePath string, producers map[string]runtime.Producer, registry strfmt.Registry, auth runtime.ClientAuthInfoWriter) (*http.Request, error) {
//
// The returned [http.Request] carries a context derived from parentCtx that:
//
// - inherits any deadline or cancellation already set on parentCtx;
// - additionally honors the per-request timeout set via [Request.SetTimeout]
// (the [runtime.ClientRequestWriter] may override the runtime default during
// WriteToRequest, which is why the derivation happens here rather than
// at the call site).
//
// The returned cancel must be invoked by the caller (typically deferred)
// once the response has been fully read; otherwise resources held by the
// derived context — including any timeout timer — are leaked.
//
// On error the cancel is invoked internally and a no-op cancel is returned,
// so callers can defer cancel unconditionally.
func (r *Request) BuildHTTPContext(parentCtx context.Context, mediaType, basePath string, producers map[string]runtime.Producer, registry strfmt.Registry, auth runtime.ClientAuthInfoWriter) (*http.Request, context.CancelFunc, error) {
if err := r.writer.WriteToRequest(r, registry); err != nil {
return nil, err
return nil, noop, err
}

ctx, cancel := deriveRequestContext(parentCtx, r.timeout)
r.buf = bytes.NewBuffer(nil)

var (
httpReq *http.Request
err error
)
if r.usesStreamingBody(mediaType) {
return r.buildStreamingRequest(mediaType, basePath, producers, registry, auth)
httpReq, err = r.buildStreamingRequest(ctx, mediaType, basePath, producers, registry, auth)
} else {
httpReq, err = r.buildBufferedRequest(ctx, mediaType, basePath, producers, registry, auth)
}
if err != nil {
cancel()
return nil, noop, err
}
return httpReq, cancel, nil
}

func noop() {}

// deriveRequestContext returns a child of parent bounded by timeout.
// If timeout == 0 the child is only canceled when the caller invokes
// cancel; any deadline already on parent is preserved. If timeout > 0
// the child uses the shortest of timeout and parent's existing deadline.
func deriveRequestContext(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
if timeout == 0 {
return context.WithCancel(parent)
}
return r.buildBufferedRequest(mediaType, basePath, producers, registry, auth)
return context.WithTimeout(parent, timeout)
}

// usesStreamingBody reports whether the request body must be assembled
Expand Down Expand Up @@ -368,7 +409,7 @@ func (r *Request) isMultipart(mediaType string) bool {
//
// Auth is trivial in this flow because the buffer is already populated when the auth helper
// asks for the body via r.GetBody().
func (r *Request) buildBufferedRequest(mediaType, basePath string, producers map[string]runtime.Producer, registry strfmt.Registry, auth runtime.ClientAuthInfoWriter) (*http.Request, error) {
func (r *Request) buildBufferedRequest(ctx context.Context, mediaType, basePath string, producers map[string]runtime.Producer, registry strfmt.Registry, auth runtime.ClientAuthInfoWriter) (*http.Request, error) {
var body io.Reader
var err error

Expand All @@ -392,7 +433,7 @@ func (r *Request) buildBufferedRequest(mediaType, basePath string, producers map
}
}

return r.assembleRequest(basePath, body)
return r.assembleRequest(ctx, basePath, body)
}

// buildStreamingRequest assembles a request whose body is a stream —
Expand All @@ -409,10 +450,10 @@ func (r *Request) buildBufferedRequest(mediaType, basePath string, producers map
// (it would otherwise park forever on pw.Write with no reader).
//
// For stream payloads it closes the user-provided io.ReadCloser.
func (r *Request) buildStreamingRequest(mediaType, basePath string, producers map[string]runtime.Producer, registry strfmt.Registry, auth runtime.ClientAuthInfoWriter) (req *http.Request, retErr error) {
func (r *Request) buildStreamingRequest(ctx context.Context, mediaType, basePath string, producers map[string]runtime.Producer, registry strfmt.Registry, auth runtime.ClientAuthInfoWriter) (req *http.Request, retErr error) {
var body io.Reader
if len(r.formFields) > 0 || len(r.fileFields) > 0 {
body = r.writeMultipartBody(mediaType)
body = r.writeMultipartBody(ctx, mediaType)
} else {
body = r.writeStreamPayload(mediaType, producers)
}
Expand All @@ -435,19 +476,19 @@ func (r *Request) buildStreamingRequest(mediaType, basePath string, producers ma
return nil, err
}

return r.assembleRequest(basePath, body)
return r.assembleRequest(ctx, basePath, body)
}

// assembleRequest is the shared tail of both flows: build the URL
// path, create the http.Request, merge static query parameters, and
// finalize headers/query.
func (r *Request) assembleRequest(basePath string, body io.Reader) (*http.Request, error) {
func (r *Request) assembleRequest(ctx context.Context, basePath string, body io.Reader) (*http.Request, error) {
urlPath, staticQueryParams, err := r.resolveURLPath(basePath)
if err != nil {
return nil, err
}

req, err := http.NewRequestWithContext(context.Background(), r.method, urlPath, body)
req, err := http.NewRequestWithContext(ctx, r.method, urlPath, body)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -625,12 +666,12 @@ func (r *Request) writeURLEncodedBody(mediaType string) (io.Reader, error) {
// The goroutine owns the pipe writer's lifecycle: it closes the
// multipart writer (flushing the closing boundary) and the pipe writer
// when it finishes or hits an error.
func (r *Request) writeMultipartBody(mediaType string) io.Reader {
func (r *Request) writeMultipartBody(ctx context.Context, mediaType string) io.Reader {
pr, pw := io.Pipe()
mp := multipart.NewWriter(pw)
r.header.Set(runtime.HeaderContentType, mangleContentType(mediaType, mp.Boundary()))

go r.streamMultipartParts(mp, pw)
go r.streamMultipartParts(ctx, mp, pw)

return pr
}
Expand All @@ -639,14 +680,23 @@ func (r *Request) writeMultipartBody(mediaType string) io.Reader {
// closing mp and pw when done.
//
// Errors are reported by closing pw with the error so the consumer of pr observes them on its next Read.
func (r *Request) streamMultipartParts(mp *multipart.Writer, pw *io.PipeWriter) {
//
// Context cancellation is observed at iteration boundaries (between
// fields and between files) and during file copy via a context-aware
// reader. When ctx is canceled the pipe writer is closed with ctx.Err()
// so the body consumer surfaces the cancellation as the read error.
func (r *Request) streamMultipartParts(ctx context.Context, mp *multipart.Writer, pw *io.PipeWriter) {
defer func() {
mp.Close()
pw.Close()
}()

for fn, v := range r.formFields {
for _, vi := range v {
if err := ctx.Err(); err != nil {
_ = pw.CloseWithError(err)
return
}
if err := mp.WriteField(fn, vi); err != nil {
logClose(err, pw)
return
Expand All @@ -664,6 +714,11 @@ func (r *Request) streamMultipartParts(mp *multipart.Writer, pw *io.PipeWriter)

for fn, f := range r.fileFields {
for _, fi := range f {
if err := ctx.Err(); err != nil {
_ = pw.CloseWithError(err)
return
}

var fileContentType string
if p, ok := fi.(runtime.ContentTyper); ok {
fileContentType = p.ContentType()
Expand Down Expand Up @@ -692,13 +747,31 @@ func (r *Request) streamMultipartParts(mp *multipart.Writer, pw *io.PipeWriter)
logClose(err, pw)
return
}
if _, err := io.Copy(wrtr, fi); err != nil {
if _, err := io.Copy(wrtr, &ctxReader{ctx: ctx, r: fi}); err != nil {
logClose(err, pw)
return
}
}
}
}

// ctxReader wraps an [io.Reader] with a context check on each Read. Once
// ctx is done, subsequent Reads return ctx.Err() instead of delegating
// to the underlying reader. It does not preempt a Read already in flight
// — that is the source's responsibility (e.g. *os.File honors Close from
// another goroutine, network sources honor SetDeadline).
type ctxReader struct {
ctx context.Context //nolint:containedctx // io.Reader's Read method has no ctx parameter, so the wrapper must carry it on the struct
r io.Reader
}

func (cr *ctxReader) Read(p []byte) (int, error) {
if err := cr.ctx.Err(); err != nil {
return 0, err
}
return cr.r.Read(p)
}

// writeStreamPayload handles a stream payload (io.Reader /
// io.ReadCloser). The bytes flow through verbatim — no producer is
// invoked. The wire Content-Type is resolved via setStreamContentType
Expand Down
Loading
Loading