Skip to content

Commit

Permalink
fix routing of host services to correct client
Browse files Browse the repository at this point in the history
Signed-off-by: Erik Sipsma <erik@sipsma.dev>
  • Loading branch information
sipsma committed Apr 30, 2024
1 parent 16df470 commit d975c04
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 11 deletions.
2 changes: 2 additions & 0 deletions core/c2h.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type c2hTunnel struct {
upstreamHost string
tunnelServiceHost string
tunnelServicePorts []PortForward
sessionID string
}

func (d *c2hTunnel) Tunnel(ctx context.Context) (rerr error) {
Expand Down Expand Up @@ -56,6 +57,7 @@ func (d *c2hTunnel) Tunnel(ctx context.Context) (rerr error) {
upstream := NewHostIPSocket(
port.Protocol.Network(),
fmt.Sprintf("%s:%d", d.upstreamHost, port.Backend),
d.sessionID,
)

sockPath := fmt.Sprintf("/upstream.%d.sock", frontend)
Expand Down
9 changes: 5 additions & 4 deletions core/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,11 +290,12 @@ func (q *Query) NewTunnelService(upstream dagql.Instance[*Service], ports []Port
}
}

func (q *Query) NewHostService(upstream string, ports []PortForward) *Service {
func (q *Query) NewHostService(upstream string, ports []PortForward, sessionID string) *Service {
return &Service{
Query: q,
HostUpstream: upstream,
HostPorts: ports,
Query: q,
HostUpstream: upstream,
HostPorts: ports,
HostSessionID: sessionID,
}
}

Expand Down
59 changes: 56 additions & 3 deletions core/schema/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ func (s *hostSchema) Install() {
`If ports are given and native is true, the ports are additive.`),

dagql.Func("service", s.service).
Impure("Value depends on the caller as it points to their host.").
Doc(`Creates a service that forwards traffic to a specified address via the host.`).
ArgDoc("ports",
`Ports to expose via the service, forwarding through the host network.`,
Expand All @@ -168,6 +169,10 @@ func (s *hostSchema) Install() {
`An empty set of ports is not valid; an error will be returned.`).
ArgDoc("host", `Upstream host to forward traffic to.`),

// hidden from external clients via the __ prefix
dagql.Func("__internalService", s.internalService).
Doc(`(Internal-only) "service" but scoped to the exact right buildkit session ID.`),

dagql.Func("setSecretFile", s.setSecretFile).
Impure("`setSecretFile` reads its value from the local machine.").
Doc(
Expand Down Expand Up @@ -279,10 +284,58 @@ type hostServiceArgs struct {
Ports []dagql.InputObject[core.PortForward]
}

func (s *hostSchema) service(ctx context.Context, parent *core.Host, args hostServiceArgs) (*core.Service, error) {
func (s *hostSchema) service(ctx context.Context, parent *core.Host, args hostServiceArgs) (inst dagql.Instance[*core.Service], err error) {
if len(args.Ports) == 0 {
return nil, errors.New("no ports specified")
return inst, errors.New("no ports specified")
}

clientMetadata, err := engine.ClientMetadataFromContext(ctx)
if err != nil {
return inst, fmt.Errorf("failed to get client metadata: %w", err)
}

portsArg := make(dagql.ArrayInput[dagql.InputObject[core.PortForward]], len(args.Ports))
copy(portsArg, args.Ports)

err = s.srv.Select(ctx, s.srv.Root(), &inst,
dagql.Selector{
Field: "host",
},
dagql.Selector{
Field: "__internalService",
Args: []dagql.NamedInput{
{
Name: "host",
Value: dagql.NewString(args.Host),
},
{
Name: "ports",
Value: portsArg,
},
{
Name: "sessionId",
Value: dagql.NewString(clientMetadata.BuildkitSessionID()),
},
},
},
)
return inst, err
}

type hostInternalServiceArgs struct {
Host string `default:"localhost"`
Ports []dagql.InputObject[core.PortForward]
SessionID string
}

func (s *hostSchema) internalService(ctx context.Context, parent *core.Host, args hostInternalServiceArgs) (*core.Service, error) {
if args.SessionID == "" {
return nil, errors.New("no session ID specified")
}

return parent.Query.NewHostService(args.Host, collectInputsSlice(args.Ports)), nil
return parent.Query.NewHostService(
args.Host,
collectInputsSlice(args.Ports),
args.SessionID,
), nil
}
3 changes: 3 additions & 0 deletions core/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ type Service struct {
HostUpstream string `json:"reverse_tunnel_upstream_addr,omitempty"`
// HostPorts configures the port forwarding rules for the host.
HostPorts []PortForward `json:"host_ports,omitempty"`
// HostSessionID is the session ID of the host (could differ from main client in the case of nested execs).
HostSessionID string `json:"host_session_id,omitempty"`
}

func (*Service) Type() *ast.Type {
Expand Down Expand Up @@ -604,6 +606,7 @@ func (svc *Service) startReverseTunnel(ctx context.Context, id *call.ID) (runnin
upstreamHost: svc.HostUpstream,
tunnelServiceHost: fullHost,
tunnelServicePorts: svc.HostPorts,
sessionID: svc.HostSessionID,
}

checkPorts := []Port{}
Expand Down
7 changes: 6 additions & 1 deletion core/socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ type Socket struct {
// IP
HostProtocol string `json:"host_protocol,omitempty"`
HostAddr string `json:"host_addr,omitempty"`

// The session ID of the host's client
SessionID string `json:"session_id,omitempty"`
}

func (*Socket) Type() *ast.Type {
Expand All @@ -36,10 +39,11 @@ func NewHostUnixSocket(absPath string) *Socket {
}
}

func NewHostIPSocket(proto string, addr string) *Socket {
func NewHostIPSocket(proto string, addr string, sessionID string) *Socket {
return &Socket{
HostAddr: addr,
HostProtocol: proto,
SessionID: sessionID,
}
}

Expand All @@ -52,6 +56,7 @@ func (socket *Socket) SSHID() string {
default:
u.Scheme = socket.HostProtocol
u.Host = socket.HostAddr
u.Fragment = socket.SessionID
}
return u.String()
}
Expand Down
28 changes: 25 additions & 3 deletions engine/buildkit/socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ package buildkit

import (
"context"
"fmt"
"net/url"

"github.com/moby/buildkit/session/sshforward"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

type socketProxy struct {
Expand All @@ -29,10 +33,28 @@ func (p *socketProxy) ForwardAgent(stream sshforward.SSH_ForwardAgentServer) err

ctx = trace.ContextWithSpanContext(ctx, p.c.spanCtx) // ensure server's span context is propagated

incomingMD, _ := metadata.FromIncomingContext(ctx)
ctx = metadata.NewOutgoingContext(ctx, incomingMD)
opts, _ := metadata.FromIncomingContext(ctx)
ctx = metadata.NewOutgoingContext(ctx, opts)

forwardAgentClient, err := sshforward.NewSSHClient(p.c.MainClientCaller.Conn()).ForwardAgent(ctx)
var connURL *url.URL
if v, ok := opts[sshforward.KeySSHID]; ok && len(v) > 0 && v[0] != "" {
var err error
connURL, err = url.Parse(v[0])
if err != nil {
return status.Errorf(codes.InvalidArgument, "invalid id: %s", err)
}
}

caller := p.c.MainClientCaller
if sessionID := connURL.Fragment; sessionID != "" {
var err error
caller, err = p.c.SessionManager.Get(ctx, sessionID, true)
if err != nil {
return fmt.Errorf("failed to get session: %w", err)
}
}

forwardAgentClient, err := sshforward.NewSSHClient(caller.Conn()).ForwardAgent(ctx)
if err != nil {
return err
}
Expand Down

0 comments on commit d975c04

Please sign in to comment.