Skip to content

Commit

Permalink
fix SetTLSFingerprintXXX does not take effect in subsequent requests(#…
Browse files Browse the repository at this point in the history
  • Loading branch information
imroc committed Oct 6, 2023
1 parent 21697a2 commit bc2bf86
Showing 1 changed file with 74 additions and 20 deletions.
94 changes: 74 additions & 20 deletions internal/http2/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,6 @@ import (
"crypto/tls"
"errors"
"fmt"
"github.com/imroc/req/v3/http2"
"github.com/imroc/req/v3/internal/ascii"
"github.com/imroc/req/v3/internal/common"
"github.com/imroc/req/v3/internal/dump"
"github.com/imroc/req/v3/internal/header"
"github.com/imroc/req/v3/internal/netutil"
"github.com/imroc/req/v3/internal/transport"
reqtls "github.com/imroc/req/v3/pkg/tls"
"io"
"io/fs"
"log"
Expand All @@ -43,6 +35,15 @@ import (
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/idna"

"github.com/imroc/req/v3/http2"
"github.com/imroc/req/v3/internal/ascii"
"github.com/imroc/req/v3/internal/common"
"github.com/imroc/req/v3/internal/dump"
"github.com/imroc/req/v3/internal/header"
"github.com/imroc/req/v3/internal/netutil"
"github.com/imroc/req/v3/internal/transport"
reqtls "github.com/imroc/req/v3/pkg/tls"
)

const (
Expand Down Expand Up @@ -157,7 +158,6 @@ func (t *Transport) pingTimeout() time.Duration {
return 15 * time.Second
}
return t.PingTimeout

}

func (t *Transport) connPool() ClientConnPool {
Expand Down Expand Up @@ -585,18 +585,72 @@ func (t *Transport) newTLSConfig(host string) *tls.Config {
return cfg
}

var zeroDialer net.Dialer

type tlsHandshakeTimeoutError struct{}

func (tlsHandshakeTimeoutError) Timeout() bool { return true }
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }

// dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS
// connection.
func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (reqtls.Conn, error) {
dialer := &tls.Dialer{
Config: cfg,
}
conn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
if t.TLSHandshakeContext != nil {
conn, err := zeroDialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
var firstTLSHost string
if firstTLSHost, _, err = net.SplitHostPort(addr); err != nil {
return nil, err
}
trace := httptrace.ContextClientTrace(ctx)
errc := make(chan error, 2)
var timer *time.Timer // for canceling TLS handshake
if d := t.TLSHandshakeTimeout; d != 0 {
timer = time.AfterFunc(d, func() {
errc <- tlsHandshakeTimeoutError{}
})
}
go func() {
if trace != nil && trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
tlsCn, tlsState, err := t.TLSHandshakeContext(ctx, firstTLSHost, conn)
if err != nil {
if timer != nil {
timer.Stop()
}
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tls.ConnectionState{}, err)
}
} else {
conn = tlsCn
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(*tlsState, nil)
}
}
errc <- err
}()
if err := <-errc; err != nil {
conn.Close()
return nil, err
} else {
tlsCn := conn.(reqtls.Conn)
return tlsCn, nil
}
} else {
dialer := &tls.Dialer{
Config: cfg,
}
conn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
tlsCn := conn.(reqtls.Conn)
return tlsCn, nil
}
tlsCn := conn.(reqtls.Conn)
return tlsCn, nil
}

func (t *Transport) dialTLS(ctx context.Context) func(string, string, *tls.Config) (net.Conn, error) {
Expand Down Expand Up @@ -1771,7 +1825,6 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error)
if a := cs.flow.available(); a > 0 {
take := a
if int(take) > maxBytes {

take = int32(maxBytes) // can't truncate int; take is int32
}
if take > int32(cc.maxFrameSize) {
Expand Down Expand Up @@ -1928,7 +1981,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
break
}
vals = append(vals, v[:p])
//writeHeader("cookie", v[:p])
// writeHeader("cookie", v[:p])
p++
// strip space after semicolon if any.
for p+1 <= len(v) && v[p] == ' ' {
Expand All @@ -1938,7 +1991,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
}
if len(v) > 0 {
vals = append(vals, v)
//writeHeader("cookie", v)
// writeHeader("cookie", v)
}
}
writeHeader("cookie", vals...)
Expand Down Expand Up @@ -2641,6 +2694,7 @@ func (b transportResponseBody) Close() error {
}
return nil
}

func (rl *clientConnReadLoop) processData(f *DataFrame) error {
cc := rl.cc
cs := rl.streamByID(f.StreamID)
Expand Down

0 comments on commit bc2bf86

Please sign in to comment.