/
engineconn.go
107 lines (90 loc) · 2.38 KB
/
engineconn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
package engineconn
import (
"context"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"os"
"github.com/Khan/genqlient/graphql"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/trace"
)
// EngineConn is a connection to a Dagger engine.
//
// It can execute GraphQL requests through the embedded graphql.Doer, reports
// the host it targets, and can be closed.
type EngineConn interface {
	// Close shuts the connection down and reports any error from doing so.
	Close() error

	// Host returns the host this connection targets.
	Host() string

	graphql.Doer
}
// Config controls how Get obtains an engine connection.
type Config struct {
	// Workdir, when non-empty, causes Get to fail if an existing session is
	// found through the environment (DAGGER_SESSION_PORT). Otherwise the whole
	// Config, including Workdir, is handed to the CLI-based connection paths.
	Workdir string

	// LogOutput is not read in this file. Presumably it receives log output
	// from the engine/CLI. NOTE(review): confirm against FromLocalCLI and
	// FromDownloadedCLI.
	LogOutput io.Writer

	// Conn, when non-nil, is returned by Get unchanged, and no other
	// connection discovery is attempted.
	Conn EngineConn

	// SkipCompatibilityCheck is not read in this file. Presumably it disables
	// a client/engine version-compatibility check. TODO confirm.
	SkipCompatibilityCheck bool
}
// ConnectParams holds what defaultHTTPClient needs to reach a running engine
// session. The JSON tags suggest it is decoded from a JSON payload elsewhere.
// TODO confirm where that payload comes from.
type ConnectParams struct {
	// Port is the TCP port on 127.0.0.1 that defaultHTTPClient dials for every
	// request, whatever host the request URL names.
	Port int `json:"port"`

	// SessionToken is sent as the HTTP basic-auth username, with an empty
	// password, on every request made through defaultHTTPClient.
	SessionToken string `json:"session_token"`
}
// Get returns a connection to a Dagger engine. It tries each source in order
// and returns the first one that is available:
//
//  1. cfg.Conn, if set, is returned as-is.
//  2. An existing session advertised through the environment
//     (DAGGER_SESSION_PORT, see FromSessionEnv). Setting cfg.Workdir is
//     rejected in this case.
//  3. A locally provided CLI binary (_EXPERIMENTAL_DAGGER_CLI_BIN, see
//     FromLocalCLI).
//  4. A downloaded CLI (see FromDownloadedCLI).
//
// A nil cfg is treated as an empty Config. Any error from a discovery step is
// returned unchanged, and no later step is attempted.
func Get(ctx context.Context, cfg *Config) (EngineConn, error) {
	if cfg == nil {
		// Avoid a nil dereference here and in the discovery helpers below.
		cfg = &Config{}
	}

	// Prefer explicitly set conn
	if cfg.Conn != nil {
		return cfg.Conn, nil
	}

	// Try DAGGER_SESSION_PORT next
	conn, ok, err := FromSessionEnv()
	if err != nil {
		return nil, err
	}
	if ok {
		if cfg.Workdir != "" {
			// The session connection cannot be used here. Release it instead of
			// leaking it. This is best-effort: the configuration error below is
			// what the caller needs to see, so a Close error is ignored.
			if conn != nil {
				_ = conn.Close()
			}
			return nil, fmt.Errorf("cannot configure workdir for existing session (please use --workdir or host.directory with absolute paths instead)")
		}
		return conn, nil
	}

	// Try _EXPERIMENTAL_DAGGER_CLI_BIN next
	conn, ok, err = FromLocalCLI(ctx, cfg)
	if err != nil {
		return nil, err
	}
	if ok {
		return conn, nil
	}

	// Fallback to downloading the CLI
	conn, err = FromDownloadedCLI(ctx, cfg)
	if err != nil {
		return nil, err
	}
	return conn, nil
}
// fallbackSpanContext returns ctx unchanged if it already carries a valid span
// context.
//
// Otherwise, if the TRACEPARENT environment variable is set (for example by
// 'dagger run'), it logs the value at debug level. It then returns ctx with
// the W3C trace context extracted from that value. If TRACEPARENT is unset,
// ctx is returned as-is.
func fallbackSpanContext(ctx context.Context) context.Context {
	if sc := trace.SpanContextFromContext(ctx); sc.IsValid() {
		return ctx
	}

	traceParent, found := os.LookupEnv("TRACEPARENT")
	if !found {
		return ctx
	}

	slog.Debug("falling back to $TRACEPARENT", "value", traceParent)
	carrier := propagation.MapCarrier{"traceparent": traceParent}
	return propagation.TraceContext{}.Extract(ctx, carrier)
}
// defaultHTTPClient returns an HTTP client that sends every request to the
// engine listening on TCP 127.0.0.1:p.Port, whatever host the request URL
// names.
//
// Each outgoing request carries:
//   - HTTP basic auth, with p.SessionToken as the username and an empty
//     password;
//   - the W3C trace context from the request's context, or from $TRACEPARENT
//     when the context has no valid span (see fallbackSpanContext).
//
// The caller's *http.Request is never modified. A clone is sent instead, as
// the http.RoundTripper contract requires.
func defaultHTTPClient(p *ConnectParams) *http.Client {
	var dialer net.Dialer
	dialTransport := &http.Transport{
		// The requested network/address are ignored: the engine is always
		// reached over loopback TCP on the session port. Passing ctx lets
		// request cancellation also abort a pending dial.
		DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
			return dialer.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", p.Port))
		},
	}
	return &http.Client{
		Transport: RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
			// Clone deep-copies the headers, so the auth and trace headers below
			// never leak into the caller's request. The new context picks up
			// $TRACEPARENT set by 'dagger run' when r's context has no span.
			req := r.Clone(fallbackSpanContext(r.Context()))
			req.SetBasicAuth(p.SessionToken, "")
			// propagate span context via headers (i.e. for Dagger-in-Dagger)
			propagation.TraceContext{}.Inject(req.Context(), propagation.HeaderCarrier(req.Header))
			return dialTransport.RoundTrip(req)
		}),
	}
}
// RoundTripperFunc adapts an ordinary function to the http.RoundTripper
// interface, much as http.HandlerFunc does for http.Handler.
type RoundTripperFunc func(*http.Request) (*http.Response, error)

// Compile-time check that RoundTripperFunc satisfies http.RoundTripper.
var _ http.RoundTripper = RoundTripperFunc(nil)

// RoundTrip implements http.RoundTripper by calling fn with req and returning
// whatever it returns.
func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return fn(req)
}