forked from gravitational/teleport
/
tls.go
196 lines (178 loc) · 5.38 KB
/
tls.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
/*
* Teleport
* Copyright (C) 2023 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package multiplexer
import (
"context"
"crypto/tls"
"errors"
"io"
"net"
"time"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/utils"
)
// TLSListenerConfig holds the settings used to build a TLSListener.
type TLSListenerConfig struct {
	// Listener is the underlying listener. Every connection it yields from
	// Accept is expected to be a *tls.Conn; anything else is closed and
	// reported as an internal usage error by Serve.
	Listener net.Listener

	// ID is included in the logging component name and is only used to
	// tell listeners apart when debugging.
	ID string

	// ReadDeadline bounds how long reads may block while the TLS handshake
	// at the start of a connection is in progress. When left at zero,
	// CheckAndSetDefaults replaces it with defaults.HandshakeReadDeadline.
	ReadDeadline time.Duration

	// Clock supplies the current time and can be overridden in tests.
	// When nil, CheckAndSetDefaults installs a real-time clock.
	Clock clockwork.Clock
}
// CheckAndSetDefaults validates the configuration and fills in any
// optional fields that were left unset.
//
// It returns a trace.BadParameter error when Listener is nil; in that case
// the receiver is left unmodified. Otherwise a nil Clock becomes a real-time
// clock and a zero ReadDeadline becomes defaults.HandshakeReadDeadline.
func (c *TLSListenerConfig) CheckAndSetDefaults() error {
	// Listener is the only mandatory field, so validate it before
	// touching anything else.
	if c.Listener == nil {
		return trace.BadParameter("missing parameter Listener")
	}

	if c.Clock == nil {
		c.Clock = clockwork.NewRealClock()
	}

	// A zero duration means "unspecified".
	if c.ReadDeadline == 0 {
		c.ReadDeadline = defaults.HandshakeReadDeadline
	}

	return nil
}
// NewTLSListener validates cfg, applies its defaults and builds a
// TLSListener around cfg.Listener.
//
// The HTTP/1.1 and HTTP/2 sub-listeners are both created with the address
// of the wrapped listener and share a single cancellable context that is
// released by Close. The returned listener does nothing until Serve is
// called.
func NewTLSListener(cfg TLSListenerConfig) (*TLSListener, error) {
	if err := cfg.CheckAndSetDefaults(); err != nil {
		return nil, trace.Wrap(err)
	}

	// One context governs the accept loop and both sub-listeners.
	ctx, cancel := context.WithCancel(context.TODO())
	addr := cfg.Listener.Addr()
	logger := log.WithFields(log.Fields{
		trace.Component: teleport.Component("mxtls", cfg.ID),
	})

	listener := &TLSListener{
		log:           logger,
		cfg:           cfg,
		http2Listener: newListener(ctx, addr),
		httpListener:  newListener(ctx, addr),
		cancel:        cancel,
		context:       ctx,
	}
	return listener, nil
}
// TLSListener sits on top of a listener that produces *tls.Conn
// connections, completes the TLS handshake for each of them, and routes the
// connection to one of two sub-listeners according to the negotiated
// application protocol (HTTP/2, or HTTP/1.1 when that protocol or no
// protocol at all was negotiated).
type TLSListener struct {
	// log is the logger entry tagged with this listener's component name.
	log *log.Entry
	// cfg is the validated configuration the listener was created with.
	cfg TLSListenerConfig
	// http2Listener receives connections that negotiated HTTP/2.
	http2Listener *Listener
	// httpListener receives connections that negotiated HTTP/1.1 or that
	// negotiated no protocol.
	httpListener *Listener
	// cancel releases context; it is invoked by Close.
	cancel context.CancelFunc
	// context is shared by Serve and both sub-listeners.
	context context.Context
}
// HTTP2 exposes the sub-listener that is handed every connection whose
// negotiated protocol is HTTP/2.
func (l *TLSListener) HTTP2() net.Listener {
	var h2 net.Listener = l.http2Listener
	return h2
}
// HTTP exposes the sub-listener that is handed every connection that
// negotiated HTTP/1.1 or did not negotiate any protocol.
func (l *TLSListener) HTTP() net.Listener {
	var h1 net.Listener = l.httpListener
	return h1
}
// Serve runs the accept loop: it accepts connections from the wrapped
// listener and, for every *tls.Conn, starts a goroutine that performs the
// handshake and forwards the connection to the matching sub-listener.
//
// Connections that are not *tls.Conn are closed and logged as an internal
// usage error. When Accept fails because the wrapped listener was closed,
// Serve waits for the listener context to be cancelled (which Close does)
// and returns nil. Any other Accept error is logged and retried after a
// fixed backoff, unless the context is cancelled first. Serve always
// returns nil.
func (l *TLSListener) Serve() error {
	// acceptBackoff is how long to pause after a failed Accept so that a
	// persistent error does not turn into a busy loop.
	const acceptBackoff = 5 * time.Second

	for {
		conn, err := l.cfg.Listener.Accept()
		if err == nil {
			tlsConn, ok := conn.(*tls.Conn)
			if !ok {
				conn.Close()
				l.log.WithFields(log.Fields{
					"src_addr": conn.RemoteAddr(),
					"dst_addr": conn.LocalAddr(),
				}).Errorf("Expected tls.Conn, got %T, internal usage error.", conn)
				continue
			}
			// The handshake can block on the peer, so never run it on
			// the accept loop itself.
			go l.detectAndForward(tlsConn)
			continue
		}

		if utils.IsUseOfClosedNetworkError(err) {
			// The wrapped listener is gone; wait for Close to finish
			// cancelling the context before reporting a clean shutdown.
			<-l.context.Done()
			return nil
		}

		// Surface the failure instead of dropping it silently; otherwise a
		// persistent accept error would be invisible to operators.
		l.log.WithError(err).Warning("Failed to accept connection, backing off.")

		select {
		case <-l.context.Done():
			return nil
		case <-time.After(acceptBackoff):
		}
	}
}
// detectAndForward completes the TLS handshake on conn and then passes the
// connection to the sub-listener that matches the negotiated application
// protocol: the HTTP/2 listener for http2.NextProtoTLS, the HTTP listener
// for teleport.HTTPNextProtoTLS or when no protocol was negotiated.
//
// The handshake is bounded by the configured ReadDeadline; the deadline is
// cleared again before the connection is handed over. The connection is
// closed here when setting or clearing the deadline fails, when the
// handshake fails, or when an unsupported protocol was negotiated.
func (l *TLSListener) detectAndForward(conn *tls.Conn) {
	// slowHandshake is the duration above which a handshake is reported
	// to help debug latency issues.
	const slowHandshake = time.Second

	if err := conn.SetReadDeadline(l.cfg.Clock.Now().Add(l.cfg.ReadDeadline)); err != nil {
		l.log.WithError(err).Debugf("Failed to set connection deadline.")
		conn.Close()
		return
	}

	start := l.cfg.Clock.Now()
	if err := conn.Handshake(); err != nil {
		// A bare EOF is not logged; every other handshake failure is.
		if !errors.Is(trace.Unwrap(err), io.EOF) {
			l.log.WithFields(log.Fields{
				"src_addr": conn.RemoteAddr(),
				"dst_addr": conn.LocalAddr(),
			}).WithError(err).Warning("Handshake failed.")
		}
		conn.Close()
		return
	}

	// Measure with the same clock that produced start, and report the
	// value that was actually compared against the threshold.
	if elapsed := l.cfg.Clock.Now().Sub(start); elapsed > slowHandshake {
		l.log.Warnf("Slow TLS handshake from %v, took %v.", conn.RemoteAddr(), elapsed)
	}

	// The handshake deadline must not apply to the forwarded connection.
	if err := conn.SetReadDeadline(time.Time{}); err != nil {
		l.log.WithError(err).Warning("Failed to reset read deadline")
		conn.Close()
		return
	}

	switch proto := conn.ConnectionState().NegotiatedProtocol; proto {
	case http2.NextProtoTLS:
		l.http2Listener.HandleConnection(l.context, conn)
	case teleport.HTTPNextProtoTLS, "":
		l.httpListener.HandleConnection(l.context, conn)
	default:
		conn.Close()
		// There is no error value to attach at this point; the protocol
		// name is the relevant detail.
		l.log.WithFields(log.Fields{
			"src_addr": conn.RemoteAddr(),
			"dst_addr": conn.LocalAddr(),
		}).Errorf("unsupported protocol: %v", proto)
	}
}
// Close shuts down the wrapped listener and then cancels the listener
// context, returning whatever error the wrapped listener's Close reported.
// Accept calls blocked on the wrapped listener are expected to be unblocked
// and return errors.
func (l *TLSListener) Close() error {
	// Cancel only after the wrapped listener has been closed; Serve waits
	// on this context once it sees the closed-listener error.
	defer l.cancel()

	closeErr := l.cfg.Listener.Close()
	return closeErr
}
// Addr reports the network address of the wrapped listener.
func (l *TLSListener) Addr() net.Addr {
	inner := l.cfg.Listener
	return inner.Addr()
}