Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Finish isolating sessions to each client #6916

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 13 additions & 17 deletions cmd/shim/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"github.com/dagger/dagger/engine/client"
"github.com/dagger/dagger/network"
"github.com/google/uuid"
"github.com/opencontainers/go-digest"
"github.com/opencontainers/runtime-spec/specs-go"
"github.com/vito/progrock"
"golang.org/x/sys/unix"
Expand Down Expand Up @@ -452,7 +451,7 @@ func setupBundle() int {
keepEnv := []string{}
for _, env := range spec.Process.Env {
switch {
case strings.HasPrefix(env, "_DAGGER_ENABLE_NESTING="):
case strings.HasPrefix(env, "_DAGGER_NESTED_CLIENT_ID="):
// keep the env var; we use it at runtime
keepEnv = append(keepEnv, env)

Expand Down Expand Up @@ -484,7 +483,9 @@ func setupBundle() int {
Options: []string{"rbind"},
Source: execMetadata.ProgSockPath,
})
case strings.HasPrefix(env, "_DAGGER_SERVER_ID="):
case strings.HasPrefix(env, "_DAGGER_ENGINE_VERSION="):
// don't need this at runtime, it is just for invalidating cache, which
// has already happened by now
case strings.HasPrefix(env, aliasPrefix):
// NB: don't keep this env var, it's only for the bundling step
// keepEnv = append(keepEnv, env)
Expand Down Expand Up @@ -611,7 +612,8 @@ func internalEnv(name string) (string, bool) {
}

func runWithNesting(ctx context.Context, cmd *exec.Cmd) error {
if _, found := internalEnv("_DAGGER_ENABLE_NESTING"); !found {
clientID, ok := internalEnv("_DAGGER_NESTED_CLIENT_ID")
if !ok {
// no nesting; run as normal
return execProcess(cmd, true)
}
Expand All @@ -629,27 +631,21 @@ func runWithNesting(ctx context.Context, cmd *exec.Cmd) error {
}
sessionPort := l.Addr().(*net.TCPAddr).Port

serverID, ok := internalEnv("_DAGGER_SERVER_ID")
if !ok {
return errors.New("missing nested client server ID")
}

parentClientIDsVal, _ := internalEnv("_DAGGER_PARENT_CLIENT_IDS")

clientParams := client.Params{
ID: clientID,
ServerID: serverID,
SecretToken: sessionToken.String(),
RunnerHost: "unix:///.runner.sock",
ParentClientIDs: strings.Fields(parentClientIDsVal),
}

if _, ok := internalEnv("_DAGGER_ENABLE_NESTING_IN_SAME_SESSION"); ok {
serverID, ok := internalEnv("_DAGGER_SERVER_ID")
if !ok {
return fmt.Errorf("missing _DAGGER_SERVER_ID")
}
clientParams.ServerID = serverID
}

moduleCallerDigest, ok := internalEnv("_DAGGER_MODULE_CALLER_DIGEST")
if ok {
clientParams.ModuleCallerDigest = digest.Digest(moduleCallerDigest)
}

progW, err := progrock.DialRPC(ctx, "unix:///.progrock.sock")
if err != nil {
return fmt.Errorf("error connecting to progrock: %w", err)
Expand Down
42 changes: 14 additions & 28 deletions core/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -1021,16 +1021,15 @@ func (container *Container) WithExec(ctx context.Context, opts ContainerExecOpts

// this allows executed containers to communicate back to this API
if opts.ExperimentalPrivilegedNesting {
// include the engine version so that these execs get invalidated if the engine/API change
runOpts = append(runOpts, llb.AddEnv("_DAGGER_ENABLE_NESTING", engine.Version))
}

if opts.ModuleCallerDigest != "" {
runOpts = append(runOpts, llb.AddEnv("_DAGGER_MODULE_CALLER_DIGEST", opts.ModuleCallerDigest.String()))
}

if opts.NestedInSameSession {
runOpts = append(runOpts, llb.AddEnv("_DAGGER_ENABLE_NESTING_IN_SAME_SESSION", ""))
clientID, err := container.Query.RegisterCaller(ctx, opts.NestedExecFunctionCall)
if err != nil {
return nil, fmt.Errorf("register caller: %w", err)
}
runOpts = append(runOpts,
llb.AddEnv("_DAGGER_NESTED_CLIENT_ID", clientID),
// include the engine version so that these execs get invalidated if the engine/API change
llb.AddEnv("_DAGGER_ENGINE_VERSION", engine.Version),
)
}

metaSt, metaSourcePath := metaMount(opts.Stdin)
Expand Down Expand Up @@ -1071,13 +1070,7 @@ func (container *Container) WithExec(ctx context.Context, opts ContainerExecOpts
}

// don't pass these through to the container when manually set, they are internal only
if name == "_DAGGER_ENABLE_NESTING" && !opts.ExperimentalPrivilegedNesting {
continue
}
if name == "_DAGGER_MODULE_CALLER_DIGEST" && opts.ModuleCallerDigest == "" {
continue
}
if name == "_DAGGER_ENABLE_NESTING_IN_SAME_SESSION" && !opts.NestedInSameSession {
if name == "_DAGGER_NESTED_CLIENT_ID" && !opts.ExperimentalPrivilegedNesting {
continue
}

Expand Down Expand Up @@ -1783,22 +1776,15 @@ type ContainerExecOpts struct {
// Redirect the command's standard error to a file in the container
RedirectStderr string `default:""`

// Provide dagger access to the executed command
// Do not use this option unless you trust the command being executed.
// The command being executed WILL BE GRANTED FULL ACCESS TO YOUR HOST FILESYSTEM
// Provide the executed command access back to the Dagger API
ExperimentalPrivilegedNesting bool `default:"false"`

// Grant the process all root capabilities
InsecureRootCapabilities bool `default:"false"`

// (Internal-only) If this exec is for a module function, this digest will be set in the
// grpc context metadata for any api requests back to the engine. It's used by the API
// server to determine which schema to serve and other module context metadata.
ModuleCallerDigest digest.Digest `name:"-"`

// (Internal-only) Used for module function execs to trigger the nested api client to
// be connected back to the same session.
NestedInSameSession bool `name:"-"`
// (Internal-only) If this is a nested exec for a Function call, this should be set
// with the metadata for that call
NestedExecFunctionCall *FunctionCall `name:"-"`
}

type BuildArg struct {
Expand Down
2 changes: 1 addition & 1 deletion core/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ func (iface *InterfaceType) Install(ctx context.Context, dag *dagql.Server) erro
})
}

res, err := callable.Call(ctx, dagql.CurrentID(ctx), &CallOpts{
res, err := callable.Call(ctx, &CallOpts{
Inputs: callInputs,
ParentVal: runtimeVal.Fields,
})
Expand Down
74 changes: 20 additions & 54 deletions core/modfunc.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@ import (
"github.com/dagger/dagger/core/pipeline"
"github.com/dagger/dagger/dagql"
"github.com/dagger/dagger/dagql/call"
"github.com/dagger/dagger/engine"
"github.com/dagger/dagger/engine/buildkit"
bkgw "github.com/moby/buildkit/frontend/gateway/client"
"github.com/moby/buildkit/util/bklog"
"github.com/opencontainers/go-digest"
ocispecs "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/vito/progrock"
)

type ModuleFunction struct {
Expand Down Expand Up @@ -111,7 +109,7 @@ func (fn *ModuleFunction) recordCall(ctx context.Context) {
analytics.Ctx(ctx).Capture(ctx, "module_call", props)
}

func (fn *ModuleFunction) Call(ctx context.Context, caller *call.ID, opts *CallOpts) (t dagql.Typed, rerr error) {
func (fn *ModuleFunction) Call(ctx context.Context, opts *CallOpts) (t dagql.Typed, rerr error) {
mod := fn.mod

lg := bklog.G(ctx).WithField("module", mod.Name()).WithField("function", fn.metadata.Name)
Expand Down Expand Up @@ -164,23 +162,6 @@ func (fn *ModuleFunction) Call(ctx context.Context, caller *call.ID, opts *CallO
})
}

callerDigestInputs := []string{}
{
callerIDDigest := caller.Digest() // FIXME(vito) canonicalize, once all that's implemented
callerDigestInputs = append(callerDigestInputs, callerIDDigest.String())
}
if !opts.Cache {
// use the ServerID so that we bust cache once-per-session
clientMetadata, err := engine.ClientMetadataFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get client metadata: %w", err)
}
callerDigestInputs = append(callerDigestInputs, clientMetadata.ServerID)
}

callerDigest := digest.FromString(strings.Join(callerDigestInputs, " "))

ctx = bklog.WithLogger(ctx, bklog.G(ctx).WithField("caller_digest", callerDigest.String()))
bklog.G(ctx).Debug("function call")
defer func() {
bklog.G(ctx).Debug("function call done")
Expand All @@ -189,54 +170,39 @@ func (fn *ModuleFunction) Call(ctx context.Context, caller *call.ID, opts *CallO
}
}()

ctr := fn.runtime

metaDir := NewScratchDirectory(mod.Query, mod.Query.Platform)
ctr, err := ctr.WithMountedDirectory(ctx, modMetaDirPath, metaDir, "", false)
if err != nil {
return nil, fmt.Errorf("failed to mount mod metadata directory: %w", err)
}

// Setup the Exec for the Function call and evaluate it
ctr, err = ctr.WithExec(ctx, ContainerExecOpts{
ModuleCallerDigest: callerDigest,
ExperimentalPrivilegedNesting: true,
NestedInSameSession: true,
})
if err != nil {
return nil, fmt.Errorf("failed to exec function: %w", err)
}

parentJSON, err := json.Marshal(opts.ParentVal)
if err != nil {
return nil, fmt.Errorf("failed to marshal parent value: %w", err)
}

callMeta := &FunctionCall{
Query: fn.root,
Name: fn.metadata.OriginalName,
Parent: parentJSON,
InputArgs: callInputs,
Query: fn.root,
Name: fn.metadata.OriginalName,
Parent: parentJSON,
InputArgs: callInputs,
Module: mod,
Cache: opts.Cache,
SkipSelfSchema: opts.SkipSelfSchema,
}
if fn.objDef != nil {
callMeta.ParentName = fn.objDef.OriginalName
}

var deps *ModDeps
if opts.SkipSelfSchema {
// Only serve the APIs of the deps of this module. This is currently only needed for the special
// case of the function used to get the definition of the module itself (which can't obviously
// be served the API its returning the definition of).
deps = mod.Deps
} else {
// by default, serve both deps and the module's own API to itself
deps = mod.Deps.Prepend(mod)
ctr := fn.runtime

metaDir := NewScratchDirectory(mod.Query, mod.Query.Platform)
ctr, err = ctr.WithMountedDirectory(ctx, modMetaDirPath, metaDir, "", false)
if err != nil {
return nil, fmt.Errorf("failed to mount mod metadata directory: %w", err)
}

err = mod.Query.RegisterFunctionCall(ctx, callerDigest, deps, fn.mod, callMeta,
progrock.FromContext(ctx).Parent)
// Setup the Exec for the Function call and evaluate it
ctr, err = ctr.WithExec(ctx, ContainerExecOpts{
ExperimentalPrivilegedNesting: true,
NestedExecFunctionCall: callMeta,
})
if err != nil {
return nil, fmt.Errorf("failed to register function call: %w", err)
return nil, fmt.Errorf("failed to exec function: %w", err)
}

_, err = ctr.Evaluate(ctx)
Expand Down
2 changes: 1 addition & 1 deletion core/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (mod *Module) Initialize(ctx context.Context, oldSelf dagql.Instance[*Modul
return nil, fmt.Errorf("failed to create module definition function for module %q: %w", mod.Name(), err)
}

result, err := getModDefFn.Call(ctx, newID, &CallOpts{Cache: true, SkipSelfSchema: true})
result, err := getModDefFn.Call(ctx, &CallOpts{Cache: true, SkipSelfSchema: true})
if err != nil {
return nil, fmt.Errorf("failed to call module %q to get functions: %w", mod.Name(), err)
}
Expand Down
9 changes: 4 additions & 5 deletions core/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"sort"

"github.com/dagger/dagger/dagql"
"github.com/dagger/dagger/dagql/call"
"github.com/moby/buildkit/solver/pb"
"github.com/vektah/gqlparser/v2/ast"
)
Expand Down Expand Up @@ -84,7 +83,7 @@ func (t *ModuleObjectType) TypeDef() *TypeDef {
}

type Callable interface {
Call(context.Context, *call.ID, *CallOpts) (dagql.Typed, error)
Call(context.Context, *CallOpts) (dagql.Typed, error)
ReturnType() (ModType, error)
ArgType(argName string) (ModType, error)
}
Expand Down Expand Up @@ -254,7 +253,7 @@ func (obj *ModuleObject) installConstructor(ctx context.Context, dag *dagql.Serv
Value: v,
})
}
return fn.Call(ctx, dagql.CurrentID(ctx), &CallOpts{
return fn.Call(ctx, &CallOpts{
Inputs: callInput,
ParentVal: nil,
})
Expand Down Expand Up @@ -351,7 +350,7 @@ func objFun(ctx context.Context, mod *Module, objDef *ObjectTypeDef, fun *Functi
sort.Slice(opts.Inputs, func(i, j int) bool {
return opts.Inputs[i].Name < opts.Inputs[j].Name
})
return modFun.Call(ctx, dagql.CurrentID(ctx), opts)
return modFun.Call(ctx, opts)
},
}, nil
}
Expand All @@ -362,7 +361,7 @@ type CallableField struct {
Return ModType
}

func (f *CallableField) Call(ctx context.Context, id *call.ID, opts *CallOpts) (dagql.Typed, error) {
func (f *CallableField) Call(ctx context.Context, opts *CallOpts) (dagql.Typed, error) {
val, ok := opts.ParentVal[f.Field.OriginalName]
if !ok {
return nil, fmt.Errorf("field %q not found on object %q", f.Field.Name, opts.ParentVal)
Expand Down