Skip to content

Commit

Permalink
add default timeout for transporters
Browse files Browse the repository at this point in the history
  • Loading branch information
ginuerzh committed Dec 22, 2018
1 parent 369b18b commit 99a0804
Show file tree
Hide file tree
Showing 9 changed files with 149 additions and 31 deletions.
14 changes: 12 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,12 @@ func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, er
option(opts)
}

timeout := opts.Timeout
if timeout <= 0 {
timeout = DialTimeout
}
if opts.Chain == nil {
return net.DialTimeout("tcp", addr, opts.Timeout)
return net.DialTimeout("tcp", addr, timeout)
}
return opts.Chain.Dial(addr)
}
Expand All @@ -103,7 +107,13 @@ func (tr *udpTransporter) Dial(addr string, options ...DialOption) (net.Conn, er
for _, option := range options {
option(opts)
}
return net.DialTimeout("udp", addr, opts.Timeout)

timeout := opts.Timeout
if timeout <= 0 {
timeout = DialTimeout
}

return net.DialTimeout("udp", addr, timeout)
}

func (tr *udpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) {
Expand Down
23 changes: 13 additions & 10 deletions forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -662,9 +662,9 @@ func TCPRemoteForwardListener(addr string, chain *Chain) (Listener, error) {

go ln.listenLoop()

if err = <-ln.errChan; err != nil {
ln.Close()
}
// if err = <-ln.errChan; err != nil {
// ln.Close()
// }

return ln, err
}
Expand All @@ -680,19 +680,22 @@ func (l *tcpRemoteForwardListener) isChainValid() bool {

func (l *tcpRemoteForwardListener) listenLoop() {
var tempDelay time.Duration
var once sync.Once
// var once sync.Once

for {
conn, err := l.accept()

once.Do(func() {
l.errChan <- err
close(l.errChan)
})
// once.Do(func() {
// l.errChan <- err
// log.Log("once.Do error:", err)
// close(l.errChan)
// })

select {
case <-l.closed:
conn.Close()
if conn != nil {
conn.Close()
}
return
default:
}
Expand All @@ -706,7 +709,7 @@ func (l *tcpRemoteForwardListener) listenLoop() {
if max := 6 * time.Second; tempDelay > max {
tempDelay = max
}
log.Logf("[rtcp] Accept error: %v; retrying in %v", err, tempDelay)
log.Logf("[rtcp] accept error: %v; retrying in %v", err, tempDelay)
time.Sleep(tempDelay)
continue
}
Expand Down
2 changes: 2 additions & 0 deletions gost.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ var (
KeepAliveTime = 180 * time.Second
// DialTimeout is the timeout of dial.
DialTimeout = 5 * time.Second
// HandshakeTimeout is the timeout of handshake.
HandshakeTimeout = 5 * time.Second
// ReadTimeout is the timeout for reading.
ReadTimeout = 5 * time.Second
// WriteTimeout is the timeout for writing.
Expand Down
17 changes: 13 additions & 4 deletions http2.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,19 +126,23 @@ func (tr *http2Transporter) Dial(addr string, options ...DialOption) (net.Conn,
}
conn.Close()

timeout := opts.Timeout
if timeout <= 0 {
timeout = DialTimeout
}
transport := http2.Transport{
TLSClientConfig: tr.tlsConfig,
DialTLS: func(network, adr string, cfg *tls.Config) (net.Conn, error) {
conn, err := opts.Chain.Dial(adr)
if err != nil {
return nil, err
}
return wrapTLSClient(conn, cfg, opts.Timeout)
return wrapTLSClient(conn, cfg, timeout)
},
}
client = &http.Client{
Transport: &transport,
Timeout: opts.Timeout,
Timeout: timeout,
}
tr.clients[addr] = client
}
Expand Down Expand Up @@ -190,6 +194,11 @@ func (tr *h2Transporter) Dial(addr string, options ...DialOption) (net.Conn, err
tr.clientMutex.Lock()
client, ok := tr.clients[addr]
if !ok {
timeout := opts.Timeout
if timeout <= 0 {
timeout = DialTimeout
}

transport := http2.Transport{
TLSClientConfig: tr.tlsConfig,
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
Expand All @@ -200,12 +209,12 @@ func (tr *h2Transporter) Dial(addr string, options ...DialOption) (net.Conn, err
if tr.tlsConfig == nil {
return conn, nil
}
return wrapTLSClient(conn, cfg, opts.Timeout)
return wrapTLSClient(conn, cfg, timeout)
},
}
client = &http.Client{
Transport: &transport,
Timeout: opts.Timeout,
Timeout: timeout,
}
tr.clients[addr] = client
}
Expand Down
19 changes: 15 additions & 4 deletions kcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,21 @@ func KCPTransporter(config *KCPConfig) Transporter {
}

func (tr *kcpTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) {
uaddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return
opts := &DialOptions{}
for _, option := range options {
option(opts)
}

tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()

session, ok := tr.sessions[addr]
if !ok {
conn, err = net.DialUDP("udp", nil, uaddr)
timeout := opts.Timeout
if timeout <= 0 {
timeout = DialTimeout
}
conn, err = net.DialTimeout("udp", addr, timeout)
if err != nil {
return
}
Expand All @@ -146,6 +150,13 @@ func (tr *kcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (
tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()

timeout := opts.Timeout
if timeout <= 0 {
timeout = HandshakeTimeout
}
conn.SetDeadline(time.Now().Add(timeout))
defer conn.SetDeadline(time.Time{})

session, ok := tr.sessions[opts.Addr]
if !ok || session.session == nil {
s, err := tr.initSession(opts.Addr, conn, config)
Expand Down
12 changes: 12 additions & 0 deletions quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ func QUICTransporter(config *QUICConfig) Transporter {
}

func (tr *quicTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) {
opts := &DialOptions{}
for _, option := range options {
option(opts)
}

tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()

Expand Down Expand Up @@ -92,6 +97,13 @@ func (tr *quicTransporter) Handshake(conn net.Conn, options ...HandshakeOption)
tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()

timeout := opts.Timeout
if timeout <= 0 {
timeout = HandshakeTimeout
}
conn.SetDeadline(time.Now().Add(timeout))
defer conn.SetDeadline(time.Time{})

session, ok := tr.sessions[opts.Addr]
if session != nil && session.conn != conn {
conn.Close()
Expand Down
30 changes: 25 additions & 5 deletions ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,15 @@ func (tr *sshForwardTransporter) Dial(addr string, options ...DialOption) (conn
tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()

timeout := opts.Timeout
if timeout <= 0 {
timeout = DialTimeout
}

session, ok := tr.sessions[addr]
if !ok || session.Closed() {
if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
conn, err = net.DialTimeout("tcp", addr, timeout)
} else {
conn, err = opts.Chain.Dial(addr)
}
Expand All @@ -152,8 +157,13 @@ func (tr *sshForwardTransporter) Handshake(conn net.Conn, options ...HandshakeOp
option(opts)
}

timeout := opts.Timeout
if timeout <= 0 {
timeout = HandshakeTimeout
}

config := ssh.ClientConfig{
Timeout: opts.Timeout,
Timeout: timeout,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
if opts.User != nil {
Expand Down Expand Up @@ -222,10 +232,15 @@ func (tr *sshTunnelTransporter) Dial(addr string, options ...DialOption) (conn n
tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()

timeout := opts.Timeout
if timeout <= 0 {
timeout = DialTimeout
}

session, ok := tr.sessions[addr]
if !ok || session.Closed() {
if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
conn, err = net.DialTimeout("tcp", addr, timeout)
} else {
conn, err = opts.Chain.Dial(addr)
}
Expand All @@ -248,8 +263,13 @@ func (tr *sshTunnelTransporter) Handshake(conn net.Conn, options ...HandshakeOpt
option(opts)
}

timeout := opts.Timeout
if timeout <= 0 {
timeout = HandshakeTimeout
}

config := ssh.ClientConfig{
Timeout: opts.Timeout,
Timeout: timeout,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
// TODO: support pubkey auth.
Expand Down Expand Up @@ -318,7 +338,7 @@ func (s *sshSession) Ping(interval, timeout time.Duration, retries int) {
return
}
if timeout <= 0 {
timeout = 10 * time.Second
timeout = PingTimeout
}

if retries == 0 {
Expand Down
25 changes: 22 additions & 3 deletions tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@ func (tr *tlsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (
if opts.TLSConfig == nil {
opts.TLSConfig = &tls.Config{InsecureSkipVerify: true}
}
return wrapTLSClient(conn, opts.TLSConfig, opts.Timeout)

timeout := opts.Timeout
if timeout <= 0 {
timeout = HandshakeTimeout
}

return wrapTLSClient(conn, opts.TLSConfig, timeout)
}

type mtlsTransporter struct {
Expand All @@ -52,6 +58,11 @@ func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Co
option(opts)
}

timeout := opts.Timeout
if timeout <= 0 {
timeout = DialTimeout
}

tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()

Expand All @@ -63,7 +74,7 @@ func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Co
}
if !ok {
if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, opts.Timeout)
conn, err = net.DialTimeout("tcp", addr, timeout)
} else {
conn, err = opts.Chain.Dial(addr)
}
Expand All @@ -82,9 +93,17 @@ func (tr *mtlsTransporter) Handshake(conn net.Conn, options ...HandshakeOption)
option(opts)
}

timeout := opts.Timeout
if timeout <= 0 {
timeout = HandshakeTimeout
}

tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()

conn.SetDeadline(time.Now().Add(timeout))
defer conn.SetDeadline(time.Time{})

session, ok := tr.sessions[opts.Addr]
if !ok || session.session == nil {
s, err := tr.initSession(opts.Addr, conn, opts)
Expand Down Expand Up @@ -265,7 +284,7 @@ func wrapTLSClient(conn net.Conn, tlsConfig *tls.Config, timeout time.Duration)
}

if timeout <= 0 {
timeout = 10 * time.Second // default timeout
timeout = HandshakeTimeout // default timeout
}

tlsConn.SetDeadline(time.Now().Add(timeout))
Expand Down
Loading

0 comments on commit 99a0804

Please sign in to comment.