-
Notifications
You must be signed in to change notification settings - Fork 5.4k
/
port_forwarder.go
414 lines (355 loc) · 12.5 KB
/
port_forwarder.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
package portforwarder
import (
"context"
"fmt"
"io"
"net"
"strings"
"github.com/cli/cli/v2/internal/codespaces/connection"
"github.com/microsoft/dev-tunnels/go/tunnels"
)
const (
githubSubjectId = "1"
InternalPortTag = "InternalPort"
UserForwardedPortTag = "UserForwardedPort"
)
const (
PrivatePortVisibility = "private"
OrgPortVisibility = "org"
PublicPortVisibility = "public"
)
const (
trafficTypeInput = "input"
trafficTypeOutput = "output"
)
type ForwardPortOpts struct {
Port int
Internal bool
KeepAlive bool
Visibility string
}
type CodespacesPortForwarder struct {
connection connection.CodespaceConnection
keepAliveReason chan string
}
type PortForwarder interface {
ForwardPortToListener(ctx context.Context, opts ForwardPortOpts, listener *net.TCPListener) error
ForwardPort(ctx context.Context, opts ForwardPortOpts) error
ConnectToForwardedPort(ctx context.Context, conn io.ReadWriteCloser, opts ForwardPortOpts) error
ListPorts(ctx context.Context) ([]*tunnels.TunnelPort, error)
UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error
KeepAlive(reason string)
GetKeepAliveReason() string
Close() error
}
// NewPortForwarder returns a new PortForwarder for the specified codespace.
func NewPortForwarder(ctx context.Context, codespaceConnection *connection.CodespaceConnection) (fwd PortForwarder, err error) {
return &CodespacesPortForwarder{
connection: *codespaceConnection,
keepAliveReason: make(chan string, 1),
}, nil
}
// ForwardPortToListener forwards the specified port to the given TCP listener.
func (fwd *CodespacesPortForwarder) ForwardPortToListener(ctx context.Context, opts ForwardPortOpts, listener *net.TCPListener) error {
err := fwd.ForwardPort(ctx, opts)
if err != nil {
return fmt.Errorf("error forwarding port: %w", err)
}
done := make(chan error)
go func() {
// Convert the port number to a uint16
port, err := convertIntToUint16(opts.Port)
if err != nil {
done <- fmt.Errorf("error converting port: %w", err)
return
}
// Ensure the port is forwarded before connecting
err = fwd.connection.TunnelClient.WaitForForwardedPort(ctx, port)
if err != nil {
done <- fmt.Errorf("wait for forwarded port failed: %v", err)
return
}
// Connect to the forwarded port
err = fwd.connectListenerToForwardedPort(ctx, opts, listener)
if err != nil {
done <- fmt.Errorf("connect to forwarded port failed: %v", err)
}
}()
select {
case err := <-done:
if err != nil {
return fmt.Errorf("error connecting to tunnel: %w", err)
}
return nil
case <-ctx.Done():
return nil
}
}
// ForwardPort informs the host that we would like to forward the given port.
func (fwd *CodespacesPortForwarder) ForwardPort(ctx context.Context, opts ForwardPortOpts) error {
// Convert the port number to a uint16
port, err := convertIntToUint16(opts.Port)
if err != nil {
return fmt.Errorf("error converting port: %w", err)
}
tunnelPort := tunnels.NewTunnelPort(port, "", "", tunnels.TunnelProtocolHttp)
// If no visibility is provided, Dev Tunnels will use the default (private)
if opts.Visibility != "" {
// Check if the requested visibility is allowed
allowed := false
for _, allowedVisibility := range fwd.connection.AllowedPortPrivacySettings {
if allowedVisibility == opts.Visibility {
allowed = true
break
}
}
// If the requested visibility is not allowed, return an error
if !allowed {
return fmt.Errorf("visibility %s is not allowed", opts.Visibility)
}
accessControlEntries := visibilityToAccessControlEntries(opts.Visibility)
if len(accessControlEntries) > 0 {
tunnelPort.AccessControl = &tunnels.TunnelAccessControl{
Entries: accessControlEntries,
}
}
}
// Tag the port as internal or user forwarded so we know if it needs to be shown in the UI
if opts.Internal {
tunnelPort.Tags = []string{InternalPortTag}
} else {
tunnelPort.Tags = []string{UserForwardedPortTag}
}
// Create the tunnel port
_, err = fwd.connection.TunnelManager.CreateTunnelPort(ctx, fwd.connection.Tunnel, tunnelPort, fwd.connection.Options)
if err != nil && !strings.Contains(err.Error(), "409") {
return fmt.Errorf("create tunnel port failed: %v", err)
}
// Connect to the tunnel
err = fwd.connection.Connect(ctx)
if err != nil {
return fmt.Errorf("connect failed: %v", err)
}
// Inform the host that we've forwarded the port locally
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
if err != nil {
return fmt.Errorf("refresh ports failed: %v", err)
}
return nil
}
// connectListenerToForwardedPort connects to the forwarded port via a local TCP port.
func (fwd *CodespacesPortForwarder) connectListenerToForwardedPort(ctx context.Context, opts ForwardPortOpts, listener *net.TCPListener) (err error) {
errc := make(chan error, 1)
sendError := func(err error) {
// Use non-blocking send, to avoid goroutines getting
// stuck in case of concurrent or sequential errors.
select {
case errc <- err:
default:
}
}
go func() {
for {
conn, err := listener.AcceptTCP()
if err != nil {
sendError(err)
return
}
// Connect to the forwarded port in a goroutine so we can accept new connections
go func() {
if err := fwd.ConnectToForwardedPort(ctx, conn, opts); err != nil {
sendError(err)
}
}()
}
}()
// Wait for an error or for the context to be cancelled
select {
case err := <-errc:
return err
case <-ctx.Done():
return ctx.Err() // canceled
}
}
// ConnectToForwardedPort connects to the forwarded port via a given ReadWriteCloser.
// Optionally, it detects traffic over the connection and sends activity signals to the server to keep the codespace from shutting down.
func (fwd *CodespacesPortForwarder) ConnectToForwardedPort(ctx context.Context, conn io.ReadWriteCloser, opts ForwardPortOpts) error {
// Create a traffic monitor to keep the session alive
if opts.KeepAlive {
conn = newTrafficMonitor(conn, fwd)
}
// Convert the port number to a uint16
port, err := convertIntToUint16(opts.Port)
if err != nil {
return fmt.Errorf("error converting port: %w", err)
}
// Connect to the forwarded port
err = fwd.connection.TunnelClient.ConnectToForwardedPort(ctx, conn, port)
if err != nil {
return fmt.Errorf("error connecting to forwarded port: %w", err)
}
return nil
}
// ListPorts fetches the list of ports that are currently forwarded.
func (fwd *CodespacesPortForwarder) ListPorts(ctx context.Context) (ports []*tunnels.TunnelPort, err error) {
ports, err = fwd.connection.TunnelManager.ListTunnelPorts(ctx, fwd.connection.Tunnel, fwd.connection.Options)
if err != nil {
return nil, fmt.Errorf("error listing ports: %w", err)
}
return ports, nil
}
// UpdatePortVisibility changes the visibility (private, org, public) of the specified port.
func (fwd *CodespacesPortForwarder) UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error {
tunnelPort, err := fwd.connection.TunnelManager.GetTunnelPort(ctx, fwd.connection.Tunnel, remotePort, fwd.connection.Options)
if err != nil {
return fmt.Errorf("error getting tunnel port: %w", err)
}
// If the port visibility isn't changing, don't do anything
if AccessControlEntriesToVisibility(tunnelPort.AccessControl.Entries) == visibility {
return nil
}
// Delete the existing tunnel port to update
err = fwd.connection.TunnelManager.DeleteTunnelPort(ctx, fwd.connection.Tunnel, uint16(remotePort), fwd.connection.Options)
if err != nil {
return fmt.Errorf("error deleting tunnel port: %w", err)
}
done := make(chan error)
go func() {
// Connect to the tunnel
err = fwd.connection.Connect(ctx)
if err != nil {
done <- fmt.Errorf("connect failed: %v", err)
return
}
// Inform the host that we've deleted the port
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
if err != nil {
done <- fmt.Errorf("refresh ports failed: %v", err)
return
}
// Re-forward the port with the updated visibility
err = fwd.ForwardPort(ctx, ForwardPortOpts{Port: remotePort, Visibility: visibility})
if err != nil {
done <- fmt.Errorf("error forwarding port: %w", err)
return
}
done <- nil
}()
// Wait for the done channel to be closed
select {
case err := <-done:
if err != nil {
// If we fail to re-forward the port, we need to forward again with the original visibility so the port is still accessible
_ = fwd.ForwardPort(ctx, ForwardPortOpts{Port: remotePort, Visibility: AccessControlEntriesToVisibility(tunnelPort.AccessControl.Entries)})
return fmt.Errorf("error connecting to tunnel: %w", err)
}
return nil
case <-ctx.Done():
return nil
}
}
// KeepAlive accepts a reason that is retained if there is no active reason
// to send to the server.
func (fwd *CodespacesPortForwarder) KeepAlive(reason string) {
select {
case fwd.keepAliveReason <- reason:
default:
// there is already an active keep alive reason
// so we can ignore this one
}
}
// GetKeepAliveReason fetches the keep alive reason from the channel and returns it.
func (fwd *CodespacesPortForwarder) GetKeepAliveReason() string {
return <-fwd.keepAliveReason
}
// Close closes the port forwarder's tunnel client connection.
func (fwd *CodespacesPortForwarder) Close() error {
return fwd.connection.Close()
}
// AccessControlEntriesToVisibility converts the access control entries used by Dev Tunnels to a friendly visibility value.
func AccessControlEntriesToVisibility(accessControlEntries []tunnels.TunnelAccessControlEntry) string {
for _, entry := range accessControlEntries {
// If we have the anonymous type (and we're not denying it), it's public
if (entry.Type == tunnels.TunnelAccessControlEntryTypeAnonymous) && (!entry.IsDeny) {
return PublicPortVisibility
}
// If we have the organizations type (and we're not denying it), it's org
if (entry.Provider == string(tunnels.TunnelAuthenticationSchemeGitHub)) && (!entry.IsDeny) {
return OrgPortVisibility
}
}
// Else, it's private
return PrivatePortVisibility
}
// visibilityToAccessControlEntries converts the given visibility to access control entries that can be used by Dev Tunnels.
func visibilityToAccessControlEntries(visibility string) []tunnels.TunnelAccessControlEntry {
switch visibility {
case PublicPortVisibility:
return []tunnels.TunnelAccessControlEntry{{
Type: tunnels.TunnelAccessControlEntryTypeAnonymous,
Subjects: []string{},
Scopes: []string{string(tunnels.TunnelAccessScopeConnect)},
}}
case OrgPortVisibility:
return []tunnels.TunnelAccessControlEntry{{
Type: tunnels.TunnelAccessControlEntryTypeOrganizations,
Subjects: []string{githubSubjectId},
Scopes: []string{
string(tunnels.TunnelAccessScopeConnect),
},
Provider: string(tunnels.TunnelAuthenticationSchemeGitHub),
}}
default:
// The tunnel manager doesn't accept empty access control entries, so we need to return a deny entry
return []tunnels.TunnelAccessControlEntry{{
Type: tunnels.TunnelAccessControlEntryTypeOrganizations,
Subjects: []string{githubSubjectId},
Scopes: []string{},
IsDeny: true,
}}
}
}
// IsInternalPort returns true if the port is internal.
func IsInternalPort(port *tunnels.TunnelPort) bool {
for _, tag := range port.Tags {
if strings.EqualFold(tag, InternalPortTag) {
return true
}
}
return false
}
// convertIntToUint16 converts the given int to a uint16.
func convertIntToUint16(port int) (uint16, error) {
var updatedPort uint16
if port >= 0 && port <= 65535 {
updatedPort = uint16(port)
} else {
return 0, fmt.Errorf("invalid port number: %d", port)
}
return updatedPort, nil
}
// trafficMonitor implements io.Reader. It keeps the session alive by notifying
// it of the traffic type during Read operations.
type trafficMonitor struct {
rwc io.ReadWriteCloser
fwd PortForwarder
}
// newTrafficMonitor returns a trafficMonitor for the specified codespace connection.
// It wraps the provided io.ReaderWriteCloser with its own Read/Write/Close methods.
func newTrafficMonitor(rwc io.ReadWriteCloser, fwd PortForwarder) *trafficMonitor {
return &trafficMonitor{rwc, fwd}
}
// Read wraps the underlying ReadWriteCloser's Read method and keeps the session alive with the "input" traffic type.
func (t *trafficMonitor) Read(p []byte) (n int, err error) {
t.fwd.KeepAlive(trafficTypeInput)
return t.rwc.Read(p)
}
// Write wraps the underlying ReadWriteCloser's Write method and keeps the session alive with the "output" traffic type.
func (t *trafficMonitor) Write(p []byte) (n int, err error) {
t.fwd.KeepAlive(trafficTypeOutput)
return t.rwc.Write(p)
}
// Close closes the underlying ReadWriteCloser.
func (t *trafficMonitor) Close() error {
return t.rwc.Close()
}