diff --git a/internal/httptransport/httptransport.go b/internal/httptransport/httptransport.go index 1f5bc34..27e2568 100644 --- a/internal/httptransport/httptransport.go +++ b/internal/httptransport/httptransport.go @@ -7,6 +7,7 @@ import ( "time" "github.com/ooni/netx/internal/httptransport/toptripper" + "github.com/ooni/netx/internal/httptransport/transactioner" "github.com/ooni/netx/model" "golang.org/x/net/http2" ) @@ -35,9 +36,9 @@ func NewTransport(beginning time.Time, handler model.Handler) *Transport { TLSHandshakeTimeout: 10 * time.Second, }, } - transport.roundTripper = toptripper.New( + transport.roundTripper = transactioner.New(toptripper.New( beginning, handler, transport.Transport, - ) + )) // Configure h2 and make sure that the custom TLSConfig we use for dialing // is actually compatible with upgrading to h2. (This mainly means we // need to make sure we include "h2" in the NextProtos array.) Because diff --git a/internal/httptransport/transactioner/transactioner.go b/internal/httptransport/transactioner/transactioner.go new file mode 100644 index 0000000..0b6f210 --- /dev/null +++ b/internal/httptransport/transactioner/transactioner.go @@ -0,0 +1,55 @@ +// Package transactioner contains the transaction assigning round tripper +package transactioner + +import ( + "context" + "net/http" + "sync/atomic" +) + +type contextkey struct{} + +var id int64 + +// WithTransactionID returns a copy of ctx with TransactionID +func WithTransactionID(ctx context.Context) context.Context { + return context.WithValue( + ctx, contextkey{}, atomic.AddInt64(&id, 1), + ) +} + +// ContextTransactionID returns the TransactionID of the context, or zero +func ContextTransactionID(ctx context.Context) int64 { + id, _ := ctx.Value(contextkey{}).(int64) + return id +} + +// Transport performs single HTTP transactions. +type Transport struct { + roundTripper http.RoundTripper +} + +// New creates a new Transport. +func New(roundTripper http.RoundTripper) *Transport { + return &Transport{ + roundTripper: roundTripper, + } +} + +// RoundTrip executes a single HTTP transaction, returning +// a Response for the provided Request. +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + ctx := WithTransactionID(req.Context()) + return t.roundTripper.RoundTrip(req.WithContext(ctx)) +} + +// CloseIdleConnections closes the idle connections. +func (t *Transport) CloseIdleConnections() { + // Adapted from net/http code + type closeIdler interface { + CloseIdleConnections() + } + if tr, ok := t.roundTripper.(closeIdler); ok { + tr.CloseIdleConnections() + } +} diff --git a/internal/httptransport/transactioner/transactioner_test.go b/internal/httptransport/transactioner/transactioner_test.go new file mode 100644 index 0000000..0a5498e --- /dev/null +++ b/internal/httptransport/transactioner/transactioner_test.go @@ -0,0 +1,55 @@ +package transactioner + +import ( + "io/ioutil" + "net/http" + "testing" +) + +type transport struct { + roundTripper http.RoundTripper + t *testing.T +} + +func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) { + ctx := req.Context() + if id := ContextTransactionID(ctx); id == 0 { + t.t.Fatal("transaction ID not set") + } + return t.roundTripper.RoundTrip(req) +} + +func TestIntegration(t *testing.T) { + client := &http.Client{ + Transport: New(&transport{ + roundTripper: http.DefaultTransport, + t: t, + }), + } + resp, err := client.Get("https://www.google.com") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + client.CloseIdleConnections() +} + +func TestIntegrationFailure(t *testing.T) { + client := &http.Client{ + Transport: New(http.DefaultTransport), + } + // This fails the request because we attempt to speak cleartext HTTP with + // a server that instead is expecting TLS. + resp, err := client.Get("http://www.google.com:443") + if err == nil { + t.Fatal("expected an error here") + } + if resp != nil { + t.Fatal("expected a nil response here") + } + client.CloseIdleConnections() +}