portforward.go
package k8s

import (
	"context"
	"fmt"
	"net"
	"net/http"
	"strconv"

	"github.com/pkg/errors"
	"k8s.io/client-go/kubernetes/typed/core/v1"
	_ "k8s.io/client-go/plugin/pkg/client/auth/gcp" // registers gcp auth provider
	"k8s.io/client-go/rest"
	"k8s.io/client-go/tools/portforward"
	"k8s.io/client-go/transport/spdy"

	"github.com/windmilleng/tilt/internal/logger"
)
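// ForwardPort forwards a local port to remotePort on the given pod.
//
// If optionalLocalPort is 0, an available local port is chosen. The returned
// closer stops the forward when called.
//
// A rough usage sketch (the client variable and port values are illustrative,
// not part of this file):
//
//	localPort, closer, err := client.ForwardPort(ctx, "default", podID, 0, 8080)
//	if err != nil {
//		return err
//	}
//	defer closer()
//	// talk to 127.0.0.1:localPort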
func (k K8sClient) ForwardPort(ctx context.Context, namespace Namespace, podID PodID, optionalLocalPort, remotePort int) (localPort int, closer func(), err error) {
	localPort = optionalLocalPort
	if localPort == 0 {
		// Preferably, we'd set the local port to 0 and let the underlying function pick a port for us,
		// to avoid the race condition potential of something else grabbing this port between
		// the call to `getAvailablePort` and whenever `portForwarder` actually binds the port.
		// The k8s client supports a local port of 0, and stores the actual local port assigned in a field,
		// but unfortunately does not export that field, so there is no way for the caller to know which
		// local port to talk to.
		localPort, err = getAvailablePort()
		if err != nil {
			return 0, nil, errors.Wrap(err, "failed to find an available local port")
		}
	}

	closer, err = k.portForwarder(ctx, k.restConfig, k.core, namespace.String(), podID, localPort, remotePort)
	if err != nil {
		return 0, nil, err
	}

	return localPort, closer, nil
}
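// portForwarder opens a SPDY port-forward connection to the pod and forwards
// localPort to remotePort until the returned closer is called or the
// connection fails.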
func portForwarder(ctx context.Context, restConfig *rest.Config, core v1.CoreV1Interface, namespace string, podID PodID, localPort int, remotePort int) (closer func(), err error) {
	transport, upgrader, err := spdy.RoundTripperFor(restConfig)
	if err != nil {
		return nil, errors.Wrap(err, "error getting roundtripper")
	}

	req := core.RESTClient().Post().
		Resource("pods").
		Namespace(namespace).
		Name(podID.String()).
		SubResource("portforward")

	dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", req.URL())

	stopChan := make(chan struct{}, 1)
	readyChan := make(chan struct{}, 1)
	ports := []string{fmt.Sprintf("%d:%d", localPort, remotePort)}

	pf, err := portforward.New(
		dialer,
		ports,
		stopChan,
		readyChan,
		logger.Get(ctx).Writer(logger.DebugLvl),
		logger.Get(ctx).Writer(logger.DebugLvl))
	if err != nil {
		return nil, errors.Wrap(err, "error forwarding port")
	}

	// ForwardPorts blocks until the forward is stopped or fails, so run it in a
	// goroutine and report its result on errChan. The channel is buffered so the
	// goroutine can exit even if nobody is listening anymore.
	errChan := make(chan error, 1)
	go func() {
		errChan <- pf.ForwardPorts()
	}()

	select {
	case err = <-errChan:
		pf.Close()
		return nil, errors.Wrap(err, "error forwarding port")
	case <-pf.Ready:
		// Once the forward is up, keep watching it in the background so we can
		// clean up and report a failure.
		go func() {
			err := <-errChan
			pf.Close()
			// logging isn't really sufficient, since we're in a goroutine and who knows where the caller
			// has moved on to by this point, but other options are much more expensive (e.g., monitoring the state
			// of the port forward from the caller and/or automatically reconnecting port forwards)
			logger.Get(ctx).Infof("error from port forward: %v", err)
		}()
		closer = func() {
			close(stopChan)
		}
		return closer, nil
	}
}
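// getAvailablePort asks the kernel for a free TCP port by listening on port 0,
// then closes the listener and returns the chosen port. As noted in ForwardPort,
// there is an inherent race: another process could grab the port before it is
// bound again.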
func getAvailablePort() (port int, err error) {
	l, err := net.Listen("tcp", ":0")
	if err != nil {
		return 0, err
	}
	// Named return values let the deferred Close propagate its error to the
	// caller if nothing else has already failed.
	defer func() {
		e := l.Close()
		if err == nil {
			err = e
		}
	}()

	_, p, err := net.SplitHostPort(l.Addr().String())
	if err != nil {
		return 0, err
	}

	port, err = strconv.Atoi(p)
	if err != nil {
		return 0, err
	}
	return port, err
}