diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 5cc4837a0..824e790b8 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -112,6 +112,7 @@ jobs: VERSION=$((${{ github.run_number }} + ${{ env.BUILD_INCREMENT }})) go build -ldflags "-X '$REPO_NAME/runner/cmd/runner/version.Version=$VERSION' -extldflags '-static'" -o dstack-runner-$GOOS-$GOARCH $REPO_NAME/runner/cmd/runner go build -ldflags "-X '$REPO_NAME/runner/cmd/shim/version.Version=$VERSION' -extldflags '-static'" -o dstack-shim-$GOOS-$GOARCH $REPO_NAME/runner/cmd/shim + echo $VERSION - uses: actions/upload-artifact@v3 with: name: dstack-runner diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 80ecadba6..ee577a901 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,14 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.2.0 + rev: v0.2.1 hooks: - id: ruff + name: ruff common args: ['--fix'] - id: ruff-format + - repo: https://github.com/golangci/golangci-lint + rev: v1.56.1 + hooks: + - id: golangci-lint-full + entry: bash -c 'cd runner && golangci-lint run -D depguard --presets import,module,unused "$@"' + stages: [manual] diff --git a/docs/docs/reference/pool/index.md b/docs/docs/reference/pool/index.md new file mode 100644 index 000000000..c2ac6e41f --- /dev/null +++ b/docs/docs/reference/pool/index.md @@ -0,0 +1,42 @@ +# dstack pool + +## What is `dstack pool` + +The primary element that enables you to precisely control how compute instances are used is the `dstack pool`. + +- Sometimes the desired instance for the task might not be available. The `dstack pool` will wait for compute instances to become available and, when possible, allocate instances before running tasks on these instances. + +- You need reserved compute instances to work on a constant load. The dstack will pre-allocate ondemand instances and allow you to run tasks on them when they are available. 
+ +- You want to speed up task startup. Searching for instances and provisioning the runner will take time. When using dstack pool, tasks will be distributed to already running instances. + +- You have your own compute instances. You can connect them to a dstack pool and use them with cloud instances. + +## How to use + +Any task that runs without setting the `--pool` argument uses a pool named `default` by default. + +When you specify a pool name for a task, for example `dstack run --pool mypool`, there are two ways the task will be executed: + +- if `mypool` exists, the task will be run on an available instance with a suitable configuration +- if `mypool` does not exist, this pool will be created and the compute instances required for the pool are created and connected to that pool. + +### CLI + +- `dstack pool list` +- `dstack pool create` +- `dstack pool show ` +- `dstack pool add ` +- `dstack pool delete` + +### Instance lifecycle + +- idle time +- reservation policy (instance termination) +- task retry policy + +### Add your own compute instance + +When connecting your own instance, it must have a public IP address for the dstack server to connect. + +To connect you need to pass the IP address and SSH credentials to the command `dstack pool add --host HOST --port PORT --ssh-private-key-file SSH_PRIVATE_KEY_FILE`. 
diff --git a/runner/cmd/runner/cmd.go b/runner/cmd/runner/cmd.go index ca99ee258..233862835 100644 --- a/runner/cmd/runner/cmd.go +++ b/runner/cmd/runner/cmd.go @@ -56,7 +56,10 @@ func App() { }, }, Action: func(c *cli.Context) error { - start(paths.tempDir, paths.homeDir, paths.workingDir, httpPort, logLevel) + err := start(paths.tempDir, paths.homeDir, paths.workingDir, httpPort, logLevel, Version) + if err != nil { + return cli.Exit(err, 1) + } return nil }, }, diff --git a/runner/cmd/runner/main.go b/runner/cmd/runner/main.go index 61142c079..9c83f8f17 100644 --- a/runner/cmd/runner/main.go +++ b/runner/cmd/runner/main.go @@ -3,37 +3,46 @@ package main import ( "context" "fmt" - "github.com/dstackai/dstack/runner/internal/log" - "github.com/dstackai/dstack/runner/internal/runner/api" - "github.com/sirupsen/logrus" "io" _ "net/http/pprof" "os" "path/filepath" + + "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/runner/api" + "github.com/sirupsen/logrus" + "github.com/ztrue/tracerr" ) func main() { App() } -func start(tempDir string, homeDir string, workingDir string, httpPort int, logLevel int) { +func start(tempDir string, homeDir string, workingDir string, httpPort int, logLevel int, version string) error { if err := os.MkdirAll(tempDir, 0755); err != nil { - log.Error(context.TODO(), "Failed to create temp directory", "err", err) - os.Exit(1) + return tracerr.Errorf("Failed to create temp directory: %w", err) } + defaultLogFile, err := log.CreateAppendFile(filepath.Join(tempDir, "default.log")) if err != nil { - log.Error(context.TODO(), "Failed to create default log file", "err", err) - os.Exit(1) + return tracerr.Errorf("Failed to create default log file: %w", err) } - defer func() { _ = defaultLogFile.Close() }() + defer func() { + err = defaultLogFile.Close() + if err != nil { + tracerr.Print(err) + } + }() + log.DefaultEntry.Logger.SetOutput(io.MultiWriter(os.Stdout, defaultLogFile)) 
log.DefaultEntry.Logger.SetLevel(logrus.Level(logLevel)) - server := api.NewServer(tempDir, homeDir, workingDir, fmt.Sprintf(":%d", httpPort)) + server := api.NewServer(tempDir, homeDir, workingDir, fmt.Sprintf(":%d", httpPort), version) log.Trace(context.TODO(), "Starting API server", "port", httpPort) if err := server.Run(); err != nil { - log.Error(context.TODO(), "Server failed", "err", err) + return tracerr.Errorf("Server failed: %w", err) } + + return nil } diff --git a/runner/cmd/runner/version.go b/runner/cmd/runner/version.go index 7b2b1de54..788aadab0 100644 --- a/runner/cmd/runner/version.go +++ b/runner/cmd/runner/version.go @@ -1,4 +1,4 @@ package main // Version A default build-time variable. The value is overridden via ldflags. -var Version = "0.0.1.dev1" +var Version = "0.0.1.dev2" diff --git a/runner/cmd/shim/main.go b/runner/cmd/shim/main.go index 1b6f7ff7b..fa7fd6c19 100644 --- a/runner/cmd/shim/main.go +++ b/runner/cmd/shim/main.go @@ -2,35 +2,29 @@ package main import ( "context" + "errors" "fmt" "log" + "net/http" "os" "path/filepath" + "time" - "github.com/dstackai/dstack/runner/internal/gerrors" + "github.com/dstackai/dstack/runner/consts" "github.com/dstackai/dstack/runner/internal/shim" "github.com/dstackai/dstack/runner/internal/shim/api" - "github.com/dstackai/dstack/runner/internal/shim/backends" "github.com/urfave/cli/v2" ) func main() { - var backendName string var args shim.CLIArgs args.Docker.SSHPort = 10022 app := &cli.App{ Name: "dstack-shim", - Usage: "Starts dstack-runner or docker container. 
Kills the VM on exit.", + Usage: "Starts dstack-runner or docker container.", Version: Version, Flags: []cli.Flag{ - &cli.StringFlag{ - Name: "backend", - Usage: "Cloud backend provider", - Required: true, - Destination: &backendName, - EnvVars: []string{"DSTACK_BACKEND"}, - }, /* Shim Parameters */ &cli.PathFlag{ Name: "home", @@ -85,18 +79,6 @@ func main() { Usage: "Starts docker container and modifies entrypoint", Flags: []cli.Flag{ /* Docker Parameters */ - &cli.BoolFlag{ - Name: "with-auth", - Usage: "Waits for registry credentials", - Destination: &args.Docker.RegistryAuthRequired, - }, - &cli.StringFlag{ - Name: "image", - Usage: "Docker image name", - Required: true, - Destination: &args.Docker.ImageName, - EnvVars: []string{"DSTACK_IMAGE_NAME"}, - }, &cli.BoolFlag{ Name: "keep-container", Usage: "Do not delete container on exit", @@ -112,48 +94,48 @@ func main() { }, Action: func(c *cli.Context) error { if args.Runner.BinaryPath == "" { - if err := args.Download("linux"); err != nil { - return gerrors.Wrap(err) + if err := args.DownloadRunner(); err != nil { + return cli.Exit(err, 1) } - defer func() { _ = os.Remove(args.Runner.BinaryPath) }() } - log.Printf("Backend: %s\n", backendName) args.Runner.TempDir = "/tmp/runner" args.Runner.HomeDir = "/root" args.Runner.WorkingDir = "/workflow" var err error + + // set dstack home path args.Shim.HomeDir, err = getDstackHome(args.Shim.HomeDir) if err != nil { - return gerrors.Wrap(err) + return cli.Exit(err, 1) } log.Printf("Docker: %+v\n", args) - server := api.NewShimServer(fmt.Sprintf(":%d", args.Shim.HTTPPort), args.Docker.RegistryAuthRequired) - return gerrors.Wrap(server.RunDocker(context.TODO(), &args)) - }, - }, - { - Name: "subprocess", - Usage: "Docker-less mode", - Action: func(c *cli.Context) error { - return gerrors.New("not implemented") + dockerRunner, err := shim.NewDockerRunner(args) + if err != nil { + return cli.Exit(err, 1) + } + + address := fmt.Sprintf(":%d", args.Shim.HTTPPort) + shimServer 
:= api.NewShimServer(address, dockerRunner, Version) + + defer func() { + shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelShutdown() + _ = shimServer.HttpServer.Shutdown(shutdownCtx) + }() + + if err := shimServer.HttpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + return cli.Exit(err, 1) + } + + return nil }, }, }, } - defer func() { - backend, err := backends.NewBackend(context.TODO(), backendName) - if err != nil { - log.Fatal(err) - } - if err = backend.Terminate(context.TODO()); err != nil { - log.Fatal(err) - } - }() - if err := app.Run(os.Args); err != nil { log.Fatal(err) } @@ -163,9 +145,10 @@ func getDstackHome(flag string) (string, error) { if flag != "" { return flag, nil } + home, err := os.UserHomeDir() if err != nil { - return "", gerrors.Wrap(err) + return "", err } - return filepath.Join(home, ".dstack"), nil + return filepath.Join(home, consts.DSTACK_DIR_PATH), nil } diff --git a/runner/cmd/shim/version.go b/runner/cmd/shim/version.go index c2dfda93c..7aa1d0aae 100644 --- a/runner/cmd/shim/version.go +++ b/runner/cmd/shim/version.go @@ -1,3 +1,3 @@ package main -var Version = "0.0.0dev1" +var Version = "0.0.0dev2" diff --git a/runner/go.mod b/runner/go.mod index 2771f9714..83f89288f 100644 --- a/runner/go.mod +++ b/runner/go.mod @@ -3,12 +3,6 @@ module github.com/dstackai/dstack/runner go 1.19 require ( - cloud.google.com/go/compute v1.23.0 - github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.1 - github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v4 v4.2.1 - github.com/aws/aws-sdk-go-v2/config v1.18.39 - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.11 - github.com/aws/aws-sdk-go-v2/service/ec2 v1.118.0 github.com/bluekeyes/go-gitdiff v0.6.0 github.com/creack/pty v1.1.18 github.com/docker/docker v24.0.6+incompatible @@ -18,28 +12,15 @@ require ( github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.1 
github.com/urfave/cli/v2 v2.25.7 + github.com/ztrue/tracerr v0.4.0 golang.org/x/crypto v0.14.0 ) require ( - cloud.google.com/go/compute/metadata v0.2.3 // indirect dario.cat/mergo v1.0.0 // indirect - github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1 // indirect - github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 // indirect - github.com/AzureAD/microsoft-authentication-library-for-go v1.1.1 // indirect github.com/Microsoft/go-winio v0.6.1 // indirect github.com/ProtonMail/go-crypto v0.0.0-20230717121422-5aa5874ade95 // indirect github.com/acomagu/bufpipe v1.0.4 // indirect - github.com/aws/aws-sdk-go-v2 v1.21.0 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.13.37 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.41 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.35 // indirect - github.com/aws/aws-sdk-go-v2/internal/ini v1.3.42 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.35 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.13.6 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.15.6 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.21.5 // indirect - github.com/aws/smithy-go v1.14.2 // indirect github.com/cloudflare/circl v1.3.3 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect @@ -49,27 +30,18 @@ require ( github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect github.com/go-git/go-billy/v5 v5.4.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang-jwt/jwt/v5 v5.0.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect - github.com/golang/protobuf v1.5.3 // indirect - github.com/google/s2a-go v0.1.4 // indirect - github.com/google/uuid v1.3.0 // indirect - github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect - github.com/googleapis/gax-go/v2 v2.11.0 // indirect github.com/h2non/filetype v1.1.3 
// indirect github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect - github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/juju/errors v0.0.0-20181118221551-089d3ea4e4d5 // indirect github.com/juju/loggo v1.0.0 // indirect github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/klauspost/compress v1.15.13 // indirect - github.com/kylelemons/godebug v1.1.0 // indirect github.com/moby/term v0.5.0 // indirect github.com/morikuni/aec v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.0.2 // indirect github.com/pjbgf/sha1cd v0.3.0 // indirect - github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect @@ -78,22 +50,13 @@ require ( github.com/ulikunitz/xz v0.5.11 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect - go.opencensus.io v0.24.0 // indirect golang.org/x/mod v0.13.0 // indirect golang.org/x/net v0.16.0 // indirect - golang.org/x/oauth2 v0.8.0 // indirect golang.org/x/sys v0.13.0 // indirect - golang.org/x/text v0.13.0 // indirect golang.org/x/tools v0.14.0 // indirect - google.golang.org/api v0.126.0 // indirect - google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto v0.0.0-20230530153820-e85fd2cbaebc // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20230530153820-e85fd2cbaebc // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc // indirect - google.golang.org/grpc v1.55.0 // indirect - google.golang.org/protobuf v1.30.0 // indirect gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect gotest.tools/v3 v3.5.0 // indirect ) diff --git a/runner/go.sum b/runner/go.sum index 
97afdac9f..6ac28bc9c 100644 --- a/runner/go.sum +++ b/runner/go.sum @@ -1,27 +1,7 @@ cloud.google.com/go v0.16.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.110.2 h1:sdFPBr6xG9/wkBbfhmUz/JmZC7X6LavQgcrVINrKiVA= -cloud.google.com/go/compute v1.23.0 h1:tP41Zoavr8ptEqaW6j+LQOnyBBhO7OkOMAGrgLopTwY= -cloud.google.com/go/compute v1.23.0/go.mod h1:4tCnrn48xsqlwSAiLf1HXMQk8CONslYbdiEZc9FEIbM= -cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY= -cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1 h1:/iHxaJhsFr0+xVFfbMr5vxz848jyiWuIEDhYq3y5odY= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.1 h1:LNHhpdK7hzUcx/k1LIcuh5k7k1LGIWLQfCjaneSj7Fc= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.1/go.mod h1:uE9zaUfEQT/nbQjVi2IblCG9iaLtZsuYZ8ne+PuQ02M= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 h1:sXr+ck84g/ZlZUOZiNELInmMgOsuGwdjjVkEIde0OtY= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM= -github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v4 v4.2.1 h1:UPeCRD+XY7QlaGQte2EVI2iOcWvUYA2XY8w5T/8v0NQ= -github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v4 v4.2.1/go.mod h1:oGV6NlB0cvi1ZbYRR2UN44QHxWFyGk+iylgD0qaMXjA= -github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal v1.1.2 h1:mLY+pNLjCUeKhgnAJWAKhEUQM+RJQo2H1fuGSw1Ky1E= 
-github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork v1.0.0 h1:nBy98uKOIfun5z6wx6jwWLrULcM0+cjBalBFZlEZ7CA= -github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.0.0 h1:ECsQtyERDVz3NP3kvDOTLvbQhqWp/x9EsGKtb4ogUr8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= -github.com/AzureAD/microsoft-authentication-library-for-go v1.1.1 h1:WpB/QDNLpMw72xHJc34BNNykqSOeEJDAWkhf0u12/Jk= -github.com/AzureAD/microsoft-authentication-library-for-go v1.1.1/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= @@ -31,50 +11,14 @@ github.com/ProtonMail/go-crypto v0.0.0-20230717121422-5aa5874ade95/go.mod h1:EjA github.com/acomagu/bufpipe v1.0.4 h1:e3H4WUzM3npvo5uv95QuJM3cQspFNtFBzvJ2oNjKIDQ= github.com/acomagu/bufpipe v1.0.4/go.mod h1:mxdxdup/WdsKVreO5GpW4+M/1CE2sMG4jeGJ2sYmHc4= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= -github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/arduino/go-paths-helper v1.2.0 h1:qDW93PR5IZUN/jzO4rCtexiwF8P4OIcOmcSgAYLZfY4= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= -github.com/aws/aws-sdk-go-v2 v1.21.0 h1:gMT0IW+03wtYJhRqTVYn0wLzwdnK9sRMcxmtfGzRdJc= -github.com/aws/aws-sdk-go-v2 v1.21.0/go.mod h1:/RfNgGmRxI+iFOB1OeJUyxiU+9s88k3pfHvDagGEp0M= -github.com/aws/aws-sdk-go-v2/config v1.18.39 h1:oPVyh6fuu/u4OiW4qcuQyEtk7U7uuNBmHmJSLg1AJsQ= -github.com/aws/aws-sdk-go-v2/config v1.18.39/go.mod h1:+NH/ZigdPckFpgB1TRcRuWCB/Kbbvkxc/iNAKTq5RhE= -github.com/aws/aws-sdk-go-v2/credentials v1.13.37 
h1:BvEdm09+ZEh2XtN+PVHPcYwKY3wIeB6pw7vPRM4M9/U= -github.com/aws/aws-sdk-go-v2/credentials v1.13.37/go.mod h1:ACLrdkd4CLZyXOghZ8IYumQbcooAcp2jo/s2xsFH8IM= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.11 h1:uDZJF1hu0EVT/4bogChk8DyjSF6fof6uL/0Y26Ma7Fg= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.11/go.mod h1:TEPP4tENqBGO99KwVpV9MlOX4NSrSLP8u3KRy2CDwA8= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.41 h1:22dGT7PneFMx4+b3pz7lMTRyN8ZKH7M2cW4GP9yUS2g= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.41/go.mod h1:CrObHAuPneJBlfEJ5T3szXOUkLEThaGfvnhTf33buas= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.35 h1:SijA0mgjV8E+8G45ltVHs0fvKpTj8xmZJ3VwhGKtUSI= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.35/go.mod h1:SJC1nEVVva1g3pHAIdCp7QsRIkMmLAgoDquQ9Rr8kYw= -github.com/aws/aws-sdk-go-v2/internal/ini v1.3.42 h1:GPUcE/Yq7Ur8YSUk6lVkoIMWnJNO0HT18GUzCWCgCI0= -github.com/aws/aws-sdk-go-v2/internal/ini v1.3.42/go.mod h1:rzfdUlfA+jdgLDmPKjd3Chq9V7LVLYo1Nz++Wb91aRo= -github.com/aws/aws-sdk-go-v2/service/ec2 v1.118.0 h1:ueSJS07XpOwCFhYTHh/Jjw856+U+u0Dv5LIIPOB1/Ns= -github.com/aws/aws-sdk-go-v2/service/ec2 v1.118.0/go.mod h1:0FhI2Rzcv5BNM3dNnbcCx2qa2naFZoAidJi11cQgzL0= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.35 h1:CdzPW9kKitgIiLV1+MHobfR5Xg25iYnyzWZhyQuSlDI= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.35/go.mod h1:QGF2Rs33W5MaN9gYdEQOBBFPLwTZkEhRwI33f7KIG0o= -github.com/aws/aws-sdk-go-v2/service/sso v1.13.6 h1:2PylFCfKCEDv6PeSN09pC/VUiRd10wi1VfHG5FrW0/g= -github.com/aws/aws-sdk-go-v2/service/sso v1.13.6/go.mod h1:fIAwKQKBFu90pBxx07BFOMJLpRUGu8VOzLJakeY+0K4= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.15.6 h1:pSB560BbVj9ZlJZF4WYj5zsytWHWKxg+NgyGV4B2L58= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.15.6/go.mod h1:yygr8ACQRY2PrEcy3xsUI357stq2AxnFM6DIsR9lij4= -github.com/aws/aws-sdk-go-v2/service/sts v1.21.5 h1:CQBFElb0LS8RojMJlxRSo/HXipvTZW2S44Lt9Mk2aYQ= 
-github.com/aws/aws-sdk-go-v2/service/sts v1.21.5/go.mod h1:VC7JDqsqiwXukYEDjoHh9U0fOJtNWh04FPQz4ct4GGU= -github.com/aws/smithy-go v1.14.2 h1:MJU9hqBGbvWZdApzpvoF2WAIJDbtjK2NDJSiJP7HblQ= -github.com/aws/smithy-go v1.14.2/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= github.com/bluekeyes/go-gitdiff v0.6.0 h1:zyDBSR/o1axUl4lD08EWkXO3I834tBimmGUB0mhrvhQ= github.com/bluekeyes/go-gitdiff v0.6.0/go.mod h1:QpfYYO1E0fTVHVZAZKiRjtSGY9823iCdvGXBcEzHGbM= github.com/bradfitz/gomemcache v0.0.0-20170208213004-1952afaa557d/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60= github.com/bwesterb/go-ristretto v1.2.3/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudflare/circl v1.3.3 h1:fE/Qz0QdIGqeWfnwq0RE0R7MI51s0M2E4Ga9kq5AEMs= github.com/cloudflare/circl v1.3.3/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= -github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= -github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4/go.mod h1:6pvJx4me5XPnfI9Z40ddWsdw2W/uZgQLFXToKeRcDiI= -github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= -github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= -github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/codeclysm/extract/v3 v3.1.0 h1:z14FpkRizce3HNHsqJoZWwj0ovzZ2hiIkmT96FQS3j8= github.com/codeclysm/extract/v3 v3.1.0/go.mod 
h1:ZJi80UG2JtfHqJI+lgJSCACttZi++dHxfWuPaMhlOfQ= github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= @@ -85,7 +29,6 @@ github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= github.com/docker/distribution v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8= github.com/docker/distribution v2.8.2+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= github.com/docker/docker v24.0.6+incompatible h1:hceabKCtUgDqPu+qm0NgsaXf28Ljf4/pWFL7xjWWDgE= @@ -97,15 +40,8 @@ github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDD github.com/elazarl/goproxy v0.0.0-20221015165544-a0805db90819 h1:RIB4cRk+lBqKK3Oy0r2gRX4ui7tuhiZq2SuTtTCi0/0= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= -github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= -github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= -github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fsnotify/fsnotify 
v1.4.3-0.20170329110642-4da3e2cfbabc/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/garyburd/redigo v1.1.1-0.20170914051019-70e1b1943d4f/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY= -github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gliderlabs/ssh v0.3.5 h1:OcaySEmAQJgyYcArR+gGGTHCyE7nvhEMTlYY+Dp8CpY= github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI= github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic= @@ -117,67 +53,25 @@ github.com/go-git/go-git/v5 v5.8.1/go.mod h1:FHFuoD6yGz5OSKEBK+aWN9Oah0q54Jxl0ab github.com/go-stack/stack v1.6.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= -github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f h1:16RtHeWGkJMc80Etb8RPCcKevXGldr57+LOyZt8zOlg= github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f/go.mod h1:ijRvpgDJDI262hYq/IQVYgf8hd8IHUs93Ol0kvMBAx4= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/lint v0.0.0-20170918230701-e5d664eb928e/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= -github.com/golang/mock v1.1.1/go.mod 
h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= -github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.1.1-0.20171103154506-982329095285/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp 
v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/s2a-go v0.1.4 h1:1kZ/sQM3srePvKs3tXAvQzo66XfcReoqFpIpIccE7Oc= -github.com/google/s2a-go v0.1.4/go.mod h1:Ej+mSEMGRnqRzjc7VtF+jdBwYG5fuJfiZ8ELkjEwM0A= -github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/googleapis/enterprise-certificate-proxy v0.2.3 h1:yk9/cqRKtT9wXZSsRH9aurXEpJX+U6FLtpYTdC3R06k= -github.com/googleapis/enterprise-certificate-proxy v0.2.3/go.mod h1:AwSRAtLfXpU5Nm3pW+v7rGDHp09LsPtGY9MduiEsR9k= github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= -github.com/googleapis/gax-go/v2 v2.11.0 h1:9V9PWXEsWnPpQhu/PeQIkS4eGzMlTLGgt80cUUI8Ki4= -github.com/googleapis/gax-go/v2 v2.11.0/go.mod h1:DxmR61SGKkGLa2xigwuZIQpkCI2S5iydzRfb3peWZJI= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gregjones/httpcache v0.0.0-20170920190843-316c5e0ff04e/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= -github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/h2non/filetype v1.1.3 h1:FKkx9QbD7HR/zjK1Ia5XiBsq9zdLi5Kf3zGyFTAFkGg= 
github.com/h2non/filetype v1.1.3/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy1HndBY= github.com/hashicorp/hcl v0.0.0-20170914154624-68e816d1c783/go.mod h1:oZtUIOe8dh44I2q6ScRibXws4Ajl+d+nod3AaR9vL5w= github.com/inconshreveable/log15 v0.0.0-20170622235902-74a0988b5f80/go.mod h1:cOaXtrgN4ScfRrD9Bre7U1thNq5RtJ8ZoP4iXVGRj6o= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= -github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= -github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= -github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= -github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a/go.mod h1:UJSiEoRfvx3hP73CvoARgeLjaIOjybY9vj8PUPPFGeU= github.com/juju/errors v0.0.0-20181118221551-089d3ea4e4d5 h1:rhqTjzJlm7EbkELJDKMTU7udov+Se0xZkWmugr6zGok= github.com/juju/errors v0.0.0-20181118221551-089d3ea4e4d5/go.mod h1:W54LbzXuIE0boCoNJfwqpmkKJ1O4TCTZMetAt6jGk7Q= @@ -198,8 +92,6 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= -github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lunixbochs/vtclean v0.0.0-20160125035106-4fbf7632a2c6/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/magiconair/properties 
v1.7.4-0.20170902060319-8d7837e64d3c/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/matryer/is v1.2.0 h1:92UTHpy8CDwaJ08GqLDzhhuixiBUUD1p3AU6PHddz4A= @@ -221,14 +113,10 @@ github.com/opencontainers/image-spec v1.0.2/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zM github.com/pelletier/go-toml v1.0.1-0.20170904195809-1d6b12b7cb29/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pjbgf/sha1cd v0.3.0 h1:4D5XXmUUBUl/xQ6IjCkEAbqXskkq/4O7LmGn0AqMDs4= github.com/pjbgf/sha1cd v0.3.0/go.mod h1:nZ1rrWOcGJ5uZgEEVL1VUM9iRQiZvWdbZjkKyFzPPsI= -github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= -github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -249,7 +137,6 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 
-github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= @@ -266,41 +153,28 @@ github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsr github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= -go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= -go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= +github.com/ztrue/tracerr v0.4.0 h1:vT5PFxwIGs7rCg9ZgJ/y0NmOpJkPCPFK8x0vVIYzd04= +github.com/ztrue/tracerr v0.4.0/go.mod h1:PaFfYlas0DfmXNpo7Eay4MFhZUONqvXM+T2HyGPpngk= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.3.1-0.20221117191849-2c476679df9a/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= 
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY= golang.org/x/mod v0.13.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net 
v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -310,31 +184,21 @@ golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.16.0 h1:7eBu7KsSvFDtSXUIDbh3aqlK4DPsZ1rByC8PFfBThos= golang.org/x/net v0.16.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/oauth2 v0.0.0-20170912212905-13449ad91cb2/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.8.0 h1:6dkIjl3j3LtZ/O3sTgZTMsLKSftL/B8Zgq4huOIIUu8= -golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= golang.org/x/sync v0.0.0-20170517211232-f52d1811a629/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod 
h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -355,19 +219,13 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20170424234030-8be79e1e0910 h1:bCMaBn7ph495H+x72gEvgcv+mDRd9dElbzo/mVCMxX4= golang.org/x/time v0.0.0-20170424234030-8be79e1e0910/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= @@ -380,48 +238,9 @@ golang.org/x/xerrors 
v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.0.0-20170921000349-586095a6e407/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= -google.golang.org/api v0.126.0 h1:q4GJq+cAdMAC7XP7njvQ4tvohGLiSlytuL4BQxbIZ+o= -google.golang.org/api v0.126.0/go.mod h1:mBwVAtz+87bEN6CbA1GtZPDOqY2R5ONPqJeIlvyo4Aw= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= -google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20170918111702-1e559d0a00ee/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto v0.0.0-20230530153820-e85fd2cbaebc h1:8DyZCyvI8mE1IdLy/60bS+52xfymkE72wv1asokgtao= -google.golang.org/genproto v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:xZnkP7mREFX5MORlOPEzLMr+90PPZQ2QWzrVTWfAq64= -google.golang.org/genproto/googleapis/api v0.0.0-20230530153820-e85fd2cbaebc h1:kVKPf/IiYSBWEWtkIn6wZXwWGCnLKcC8oWfZvXjsGnM= -google.golang.org/genproto/googleapis/api 
v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:vHYtlOoi6TsQ3Uk2yxR7NI5z8uoV+3pZtR4jmHIkRig= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc h1:XSJ8Vk1SWuNr8S18z1NZSziL0CPIXLCCMDOEFtHBOFc= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA= google.golang.org/grpc v1.2.1-0.20170921194603-d4b75ebd4f9f/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= -google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= -google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= -google.golang.org/grpc v1.45.0/go.mod h1:lN7owxKUQEqMfSyQikvvk5tf/6zMPsrK+ONuO11+0rQ= -google.golang.org/grpc v1.55.0 h1:3Oj82/tFSCeUrRTg/5E/7d/W5A1tj6Ky1ABAuZuv5ag= -google.golang.org/grpc v1.55.0/go.mod h1:iYEXKGkEBhg1PjZQvoYEVPTDkHo1/bjTnfwTeGONTY8= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= 
-google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20160105164936-4f90aeace3a2/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -432,14 +251,11 @@ gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22/go.mod h1:yeKp02qBN3iKW1OzL3M gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME= gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 
v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY= gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/runner/internal/api/common.go b/runner/internal/api/common.go index 39bbe1b8c..7c4ceba8a 100644 --- a/runner/internal/api/common.go +++ b/runner/internal/api/common.go @@ -4,11 +4,12 @@ import ( "encoding/json" "errors" "fmt" - "github.com/dstackai/dstack/runner/internal/log" - "github.com/golang/gddo/httputil/header" "io" "net/http" "strings" + + "github.com/dstackai/dstack/runner/internal/log" + "github.com/golang/gddo/httputil/header" ) type Error struct { diff --git a/runner/internal/common/interpolator.go b/runner/internal/common/interpolator.go index 68b22e6b8..733114181 100644 --- a/runner/internal/common/interpolator.go +++ b/runner/internal/common/interpolator.go @@ -3,9 +3,10 @@ package common import ( "context" "fmt" + "strings" + "github.com/dstackai/dstack/runner/internal/gerrors" "github.com/dstackai/dstack/runner/internal/log" - "strings" ) const ( diff --git a/runner/internal/common/interpolator_test.go b/runner/internal/common/interpolator_test.go index fbb8d6667..e14a24874 100644 --- a/runner/internal/common/interpolator_test.go +++ b/runner/internal/common/interpolator_test.go @@ -2,8 +2,9 @@ package common import ( "context" - "github.com/stretchr/testify/assert" "testing" + + "github.com/stretchr/testify/assert" ) func TestPlainText(t *testing.T) { diff --git a/runner/internal/executor/base.go b/runner/internal/executor/base.go index 36e441d4a..bf3eeb291 100644 --- a/runner/internal/executor/base.go +++ b/runner/internal/executor/base.go @@ -2,6 +2,7 @@ package executor import ( "context" + "github.com/dstackai/dstack/runner/internal/schemas" ) diff --git 
a/runner/internal/executor/exec_test.go b/runner/internal/executor/exec_test.go index fc075e3b7..841f4e6b1 100644 --- a/runner/internal/executor/exec_test.go +++ b/runner/internal/executor/exec_test.go @@ -1,8 +1,9 @@ package executor import ( - "github.com/stretchr/testify/assert" "testing" + + "github.com/stretchr/testify/assert" ) func TestJoinRelPath(t *testing.T) { diff --git a/runner/internal/executor/executor_test.go b/runner/internal/executor/executor_test.go index c6b524574..ff61910a6 100644 --- a/runner/internal/executor/executor_test.go +++ b/runner/internal/executor/executor_test.go @@ -5,14 +5,15 @@ import ( "bytes" "context" "fmt" - "github.com/dstackai/dstack/runner/internal/schemas" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "io" "os" "path/filepath" "testing" "time" + + "github.com/dstackai/dstack/runner/internal/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // todo test get history diff --git a/runner/internal/executor/logs.go b/runner/internal/executor/logs.go index 71228b3db..807071eeb 100644 --- a/runner/internal/executor/logs.go +++ b/runner/internal/executor/logs.go @@ -1,8 +1,9 @@ package executor import ( - "github.com/dstackai/dstack/runner/internal/schemas" "sync" + + "github.com/dstackai/dstack/runner/internal/schemas" ) type appendWriter struct { diff --git a/runner/internal/executor/timestamp.go b/runner/internal/executor/timestamp.go index d93a649b3..a9463c04c 100644 --- a/runner/internal/executor/timestamp.go +++ b/runner/internal/executor/timestamp.go @@ -2,9 +2,10 @@ package executor import ( "context" - "github.com/dstackai/dstack/runner/internal/log" "sync" "time" + + "github.com/dstackai/dstack/runner/internal/log" ) type MonotonicTimestamp struct { diff --git a/runner/internal/log/log.go b/runner/internal/log/log.go index 5fd03fd51..99478a8f9 100644 --- a/runner/internal/log/log.go +++ b/runner/internal/log/log.go @@ -3,10 +3,11 @@ package log 
import ( "context" "fmt" - "github.com/dstackai/dstack/runner/internal/gerrors" - "github.com/sirupsen/logrus" "io" "os" + + "github.com/dstackai/dstack/runner/internal/gerrors" + "github.com/sirupsen/logrus" ) type loggerKey struct{} diff --git a/runner/internal/repo/manager.go b/runner/internal/repo/manager.go index 0f5a9d254..3d2ad42d6 100644 --- a/runner/internal/repo/manager.go +++ b/runner/internal/repo/manager.go @@ -3,9 +3,9 @@ package repo import ( "context" "fmt" - "github.com/dstackai/dstack/runner/internal/gerrors" "os" + "github.com/dstackai/dstack/runner/internal/gerrors" "github.com/dstackai/dstack/runner/internal/log" "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/plumbing" diff --git a/runner/internal/runner/api/http.go b/runner/internal/runner/api/http.go index 1b0cef495..74da7225c 100644 --- a/runner/internal/runner/api/http.go +++ b/runner/internal/runner/api/http.go @@ -20,13 +20,16 @@ func (s *Server) healthcheckGetHandler(w http.ResponseWriter, r *http.Request) ( defer s.executor.RUnlock() return &schemas.HealthcheckResponse{ Service: "dstack-runner", + Version: s.version, }, nil } func (s *Server) submitPostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { s.executor.Lock() defer s.executor.Unlock() - if s.executor.GetRunnerState() != executor.WaitSubmit { + state := s.executor.GetRunnerState() + if state != executor.WaitSubmit { + log.Warning(r.Context(), "Executor doesn't wait submit", "current_state", state) return nil, &api.Error{Status: http.StatusConflict} } diff --git a/runner/internal/runner/api/http_test.go b/runner/internal/runner/api/http_test.go index c0b877068..df5ef8575 100644 --- a/runner/internal/runner/api/http_test.go +++ b/runner/internal/runner/api/http_test.go @@ -1,4 +1,45 @@ package api -// todo test 409 on wrong requests order -// todo test submit wait timeout +import ( + "context" + "net/http/httptest" + "strings" + "testing" + + common "github.com/dstackai/dstack/runner/internal/api" 
+ "github.com/dstackai/dstack/runner/internal/shim" + "github.com/dstackai/dstack/runner/internal/shim/api" +) + +type DummyRunner struct { + State shim.RunnerStatus +} + +func (ds DummyRunner) GetState() shim.RunnerStatus { + return ds.State +} + +func (ds DummyRunner) Run(context.Context, shim.DockerImageConfig) error { + return nil +} + +func TestHealthcheck(t *testing.T) { + + request := httptest.NewRequest("GET", "/api/healthcheck", nil) + responseRecorder := httptest.NewRecorder() + + server := api.NewShimServer(":12345", DummyRunner{}, "0.0.1.dev2") + + f := common.JSONResponseHandler("GET", server.HealthcheckGetHandler) + f(responseRecorder, request) + + if responseRecorder.Code != 200 { + t.Errorf("Want status '%d', got '%d'", 200, responseRecorder.Code) + } + + expected := "{\"service\":\"dstack-shim\",\"version\":\"0.0.1.dev2\"}" + + if strings.TrimSpace(responseRecorder.Body.String()) != expected { + t.Errorf("Want '%s', got '%s'", expected, responseRecorder.Body.String()) + } +} diff --git a/runner/internal/runner/api/server.go b/runner/internal/runner/api/server.go index 5d4952246..9c469e070 100644 --- a/runner/internal/runner/api/server.go +++ b/runner/internal/runner/api/server.go @@ -30,9 +30,11 @@ type Server struct { executor executor.Executor cancelRun context.CancelFunc + + version string } -func NewServer(tempDir string, homeDir string, workingDir string, address string) *Server { +func NewServer(tempDir string, homeDir string, workingDir string, address string, version string) *Server { mux := http.NewServeMux() s := &Server{ srv: &http.Server{ @@ -51,6 +53,8 @@ func NewServer(tempDir string, homeDir string, workingDir string, address string logsWaitDuration: 30 * time.Second, executor: executor.NewRunExecutor(tempDir, homeDir, workingDir), + + version: version, } mux.HandleFunc("/api/healthcheck", api.JSONResponseHandler("GET", s.healthcheckGetHandler)) mux.HandleFunc("/api/submit", api.JSONResponseHandler("POST", s.submitPostHandler)) diff 
--git a/runner/internal/runner/api/submit_test.go b/runner/internal/runner/api/submit_test.go new file mode 100644 index 000000000..9ff24efd0 --- /dev/null +++ b/runner/internal/runner/api/submit_test.go @@ -0,0 +1,47 @@ +//go:build !race + +package api + +import ( + "net/http/httptest" + "strings" + "testing" + + common "github.com/dstackai/dstack/runner/internal/api" + "github.com/dstackai/dstack/runner/internal/shim" + "github.com/dstackai/dstack/runner/internal/shim/api" +) + +func TestSubmit(t *testing.T) { + + request := httptest.NewRequest("POST", "/api/submit", strings.NewReader("{\"image_name\":\"ubuntu\"}")) + responseRecorder := httptest.NewRecorder() + + dummyRunner := DummyRunner{} + dummyRunner.State = shim.Pending + + server := api.NewShimServer(":12340", &dummyRunner, "0.0.1.dev2") + + firstSubmitPost := common.JSONResponseHandler("POST", server.SubmitPostHandler) + firstSubmitPost(responseRecorder, request) + + if responseRecorder.Code != 200 { + t.Errorf("Want status '%d', got '%d'", 200, responseRecorder.Code) + } + + t.Logf("%v", responseRecorder.Result()) + + dummyRunner.State = shim.Pulling + + request = httptest.NewRequest("POST", "/api/submit", strings.NewReader("{\"image_name\":\"ubuntu\"}")) + responseRecorder = httptest.NewRecorder() + + secondSubmitPost := common.JSONResponseHandler("POST", server.SubmitPostHandler) + secondSubmitPost(responseRecorder, request) + + t.Logf("%v", responseRecorder.Result()) + + if responseRecorder.Code != 409 { + t.Errorf("Want status '%d', got '%d'", 409, responseRecorder.Code) + } +} diff --git a/runner/internal/runner/api/ws.go b/runner/internal/runner/api/ws.go index 26a52cd32..cade1170a 100644 --- a/runner/internal/runner/api/ws.go +++ b/runner/internal/runner/api/ws.go @@ -2,10 +2,11 @@ package api import ( "context" - "github.com/dstackai/dstack/runner/internal/log" - "github.com/gorilla/websocket" "net/http" "time" + + "github.com/dstackai/dstack/runner/internal/log" + "github.com/gorilla/websocket" 
) var upgrader = websocket.Upgrader{ diff --git a/runner/internal/schemas/schemas.go b/runner/internal/schemas/schemas.go index bd3150849..3253d840c 100644 --- a/runner/internal/schemas/schemas.go +++ b/runner/internal/schemas/schemas.go @@ -81,6 +81,7 @@ type Gateway struct { type HealthcheckResponse struct { Service string `json:"service"` + Version string `json:"version"` } func (d *RepoData) FormatURL(format string) string { diff --git a/runner/internal/shim/api/http.go b/runner/internal/shim/api/http.go index 160bdf2fa..5508dfebf 100644 --- a/runner/internal/shim/api/http.go +++ b/runner/internal/shim/api/http.go @@ -1,57 +1,53 @@ package api import ( - "encoding/base64" - "encoding/json" + "context" + "fmt" "log" "net/http" - "github.com/docker/docker/api/types/registry" "github.com/dstackai/dstack/runner/internal/api" "github.com/dstackai/dstack/runner/internal/shim" ) -func (s *ShimServer) healthcheckGetHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (s *ShimServer) HealthcheckGetHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { s.mu.RLock() defer s.mu.RUnlock() return &HealthcheckResponse{ Service: "dstack-shim", + Version: s.version, }, nil } -func (s *ShimServer) registryAuthPostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (s *ShimServer) SubmitPostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { s.mu.RLock() defer s.mu.RUnlock() - if s.state != shim.WaitRegistryAuth { + if s.runner.GetState() != shim.Pending { return nil, &api.Error{Status: http.StatusConflict} } - var body RegistryAuthBody + var body DockerTaskBody if err := api.DecodeJSONBody(w, r, &body, true); err != nil { log.Println("Failed to decode submit body", "err", err) return nil, err } - authConfig := registry.AuthConfig{ - Username: body.Username, - Password: body.Password, - } - encodedConfig, err := json.Marshal(authConfig) - if err != nil { - log.Println("Failed to encode auth 
config", "err", err) - return nil, err - } - s.registryAuth <- base64.URLEncoding.EncodeToString(encodedConfig) + go func(taskParams shim.DockerImageConfig) { + err := s.runner.Run(context.TODO(), taskParams) + if err != nil { + fmt.Printf("failed Run %v", err) + } + }(body.TaskParams()) return nil, nil } -func (s *ShimServer) pullGetHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (s *ShimServer) PullGetHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { s.mu.RLock() defer s.mu.RUnlock() return &PullResponse{ - State: s.state, + State: string(s.runner.GetState()), }, nil } diff --git a/runner/internal/shim/api/schemas.go b/runner/internal/shim/api/schemas.go index 069b1aa66..8f86b8a21 100644 --- a/runner/internal/shim/api/schemas.go +++ b/runner/internal/shim/api/schemas.go @@ -1,14 +1,27 @@ package api -type RegistryAuthBody struct { - Username string `json:"username"` - Password string `json:"password"` +import "github.com/dstackai/dstack/runner/internal/shim" + +type DockerTaskBody struct { + Username string `json:"username"` + Password string `json:"password"` + ImageName string `json:"image_name"` } type HealthcheckResponse struct { Service string `json:"service"` + Version string `json:"version"` } type PullResponse struct { State string `json:"state"` } + +func (ra DockerTaskBody) TaskParams() shim.DockerImageConfig { + res := shim.DockerImageConfig{ + ImageName: ra.ImageName, + Username: ra.Username, + Password: ra.Password, + } + return res +} diff --git a/runner/internal/shim/api/server.go b/runner/internal/shim/api/server.go index 6ef8a3a13..ce1bcfd59 100644 --- a/runner/internal/shim/api/server.go +++ b/runner/internal/shim/api/server.go @@ -2,65 +2,41 @@ package api import ( "context" - "errors" "net/http" "sync" - "time" "github.com/dstackai/dstack/runner/internal/api" - "github.com/dstackai/dstack/runner/internal/gerrors" "github.com/dstackai/dstack/runner/internal/shim" ) +type TaskRunner interface { + 
Run(context.Context, shim.DockerImageConfig) error + GetState() shim.RunnerStatus +} + type ShimServer struct { - srv *http.Server + HttpServer *http.Server + mu sync.RWMutex + + runner TaskRunner - mu sync.RWMutex - registryAuth chan string - state string + version string } -func NewShimServer(address string, registryAuthRequired bool) *ShimServer { +func NewShimServer(address string, runner TaskRunner, version string) *ShimServer { mux := http.NewServeMux() s := &ShimServer{ - srv: &http.Server{ + HttpServer: &http.Server{ Addr: address, Handler: mux, }, - registryAuth: make(chan string, 1), - state: shim.WaitRegistryAuth, - } - if registryAuthRequired { - mux.HandleFunc("/api/registry_auth", api.JSONResponseHandler("POST", s.registryAuthPostHandler)) - } else { - close(s.registryAuth) // no credentials ever would be sent + runner: runner, + + version: version, } - mux.HandleFunc("/api/healthcheck", api.JSONResponseHandler("GET", s.healthcheckGetHandler)) - mux.HandleFunc("/api/pull", api.JSONResponseHandler("GET", s.pullGetHandler)) + mux.HandleFunc("/api/submit", api.JSONResponseHandler("POST", s.SubmitPostHandler)) + mux.HandleFunc("/api/healthcheck", api.JSONResponseHandler("GET", s.HealthcheckGetHandler)) + mux.HandleFunc("/api/pull", api.JSONResponseHandler("GET", s.PullGetHandler)) return s } - -func (s *ShimServer) RunDocker(ctx context.Context, params shim.DockerParameters) error { - go func() { - if err := s.srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - panic(err) - } - }() - defer func() { - shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), 5*time.Second) - defer cancelShutdown() - _ = s.srv.Shutdown(shutdownCtx) - }() - return gerrors.Wrap(shim.RunDocker(ctx, params, s)) -} - -func (s *ShimServer) GetRegistryAuth() <-chan string { - return s.registryAuth -} - -func (s *ShimServer) SetState(state string) { - s.mu.Lock() - defer s.mu.Unlock() - s.state = state -} diff --git 
a/runner/internal/shim/backends/aws.go b/runner/internal/shim/backends/aws.go deleted file mode 100644 index 2b802e275..000000000 --- a/runner/internal/shim/backends/aws.go +++ /dev/null @@ -1,74 +0,0 @@ -package backends - -import ( - "bytes" - "context" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" - "github.com/aws/aws-sdk-go-v2/service/ec2" - "github.com/dstackai/dstack/runner/internal/gerrors" - "io" -) - -type AWSBackend struct { - region string - instanceId string - spot bool -} - -func init() { - register("aws", NewAWSBackend) -} - -func NewAWSBackend(ctx context.Context) (Backend, error) { - cfg, err := config.LoadDefaultConfig(ctx) - if err != nil { - return nil, gerrors.Wrap(err) - } - - client := imds.NewFromConfig(cfg) - region, err := client.GetRegion(ctx, &imds.GetRegionInput{}) - if err != nil { - return nil, gerrors.Wrap(err) - } - lifecycle, err := getAWSMetadata(ctx, client, "instance-life-cycle") - if err != nil { - return nil, gerrors.Wrap(err) - } - instanceId, err := getAWSMetadata(ctx, client, "instance-id") - if err != nil { - return nil, gerrors.Wrap(err) - } - - return &AWSBackend{ - region: region.Region, - instanceId: instanceId, - spot: lifecycle == "spot", - }, nil -} - -func (b *AWSBackend) Terminate(ctx context.Context) error { - cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(b.region)) - if err != nil { - return gerrors.Wrap(err) - } - client := ec2.NewFromConfig(cfg) - _, err = client.TerminateInstances(ctx, &ec2.TerminateInstancesInput{ - InstanceIds: []string{b.instanceId}, - }) - return gerrors.Wrap(err) -} - -func getAWSMetadata(ctx context.Context, client *imds.Client, path string) (string, error) { - resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{ - Path: path, - }) - if err != nil { - return "", gerrors.Wrap(err) - } - var b bytes.Buffer - if _, err = io.Copy(&b, resp.Content); err != nil { - return "", err - } - return b.String(), nil -} diff --git 
a/runner/internal/shim/backends/azure.go b/runner/internal/shim/backends/azure.go deleted file mode 100644 index a06515e86..000000000 --- a/runner/internal/shim/backends/azure.go +++ /dev/null @@ -1,84 +0,0 @@ -package backends - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - - "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v4" - "github.com/dstackai/dstack/runner/internal/gerrors" -) - -type AzureBackend struct { - subscriptionId string - resourceGroup string - vmName string -} - -func init() { - register("azure", NewAzureBackend) -} - -func NewAzureBackend(ctx context.Context) (Backend, error) { - metadata, err := getAzureMetadata(ctx, nil) - if err != nil { - return nil, gerrors.Wrap(err) - } - return &AzureBackend{ - subscriptionId: metadata.SubscriptionId, - resourceGroup: metadata.ResourceGroupName, - vmName: metadata.Name, - }, nil -} - -func (b *AzureBackend) Terminate(ctx context.Context) error { - credential, err := azidentity.NewManagedIdentityCredential(nil) - if err != nil { - return gerrors.Wrap(err) - } - computeClient, err := armcompute.NewVirtualMachinesClient(b.subscriptionId, credential, nil) - if err != nil { - return gerrors.Wrap(err) - } - _, err = computeClient.BeginDelete(ctx, b.resourceGroup, b.vmName, nil) - return gerrors.Wrap(err) -} - -type AzureComputeInstanceMetadata struct { - SubscriptionId string `json:"subscriptionId"` - ResourceGroupName string `json:"resourceGroupName"` - Name string `json:"name"` -} - -type AzureInstanceMetadata struct { - Compute AzureComputeInstanceMetadata `json:"compute"` -} - -func getAzureMetadata(ctx context.Context, url *string) (*AzureComputeInstanceMetadata, error) { - baseURL := "http://169.254.169.254" - if url != nil { - baseURL = *url - } - req, err := http.NewRequestWithContext( - ctx, - http.MethodGet, - fmt.Sprintf("%s/metadata/instance?api-version=2021-02-01", baseURL), - nil, - ) - if err != nil { 
- return nil, gerrors.Wrap(err) - } - req.Header.Add("Metadata", "true") - res, err := http.DefaultClient.Do(req) - if err != nil { - return nil, gerrors.Wrap(err) - } - decoder := json.NewDecoder(res.Body) - var metadata AzureInstanceMetadata - if err = decoder.Decode(&metadata); err != nil { - return nil, gerrors.Wrap(err) - } - return &metadata.Compute, nil -} diff --git a/runner/internal/shim/backends/azure_test.go b/runner/internal/shim/backends/azure_test.go deleted file mode 100644 index 8ea63d20d..000000000 --- a/runner/internal/shim/backends/azure_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package backends - -import ( - "context" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestGetsAzureMetadata(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte( - `{"compute": - { - "subscriptionId":"test_subscription", - "resourceGroupName":"test_group", - "name":"test_vm" - } - }`, - )) - })) - defer server.Close() - metadata, err := getAzureMetadata(context.TODO(), &server.URL) - assert.Equal(t, nil, err) - assert.Equal(t, AzureComputeInstanceMetadata{ - SubscriptionId: "test_subscription", - ResourceGroupName: "test_group", - Name: "test_vm", - }, *metadata) -} diff --git a/runner/internal/shim/backends/backends.go b/runner/internal/shim/backends/backends.go deleted file mode 100644 index 64e7d3728..000000000 --- a/runner/internal/shim/backends/backends.go +++ /dev/null @@ -1,32 +0,0 @@ -package backends - -import ( - "context" - "github.com/dstackai/dstack/runner/internal/gerrors" - "sync" -) - -type Backend interface { - Terminate(context.Context) error -} - -type BackendFactory func(ctx context.Context) (Backend, error) - -var backends = make(map[string]BackendFactory) -var mu = sync.Mutex{} - -func NewBackend(ctx context.Context, name string) (Backend, error) { - mu.Lock() - defer mu.Unlock() - 
factory, ok := backends[name] - if !ok { - return nil, gerrors.Newf("unknown backend %s", name) - } - return factory(ctx) -} - -func register(name string, factory BackendFactory) { - mu.Lock() - defer mu.Unlock() - backends[name] = factory -} diff --git a/runner/internal/shim/backends/gcp.go b/runner/internal/shim/backends/gcp.go deleted file mode 100644 index 57f668b52..000000000 --- a/runner/internal/shim/backends/gcp.go +++ /dev/null @@ -1,71 +0,0 @@ -package backends - -import ( - compute "cloud.google.com/go/compute/apiv1" - "cloud.google.com/go/compute/apiv1/computepb" - "context" - "fmt" - "github.com/dstackai/dstack/runner/internal/gerrors" - "io" - "net/http" - "strings" -) - -type GCPBackend struct { - instanceName string - project string - zone string -} - -func init() { - register("gcp", NewGCPBackend) -} - -func NewGCPBackend(ctx context.Context) (Backend, error) { - instanceName, err := getGCPMetadata(ctx, "/instance/name") - if err != nil { - return nil, gerrors.Wrap(err) - } - projectZone, err := getGCPMetadata(ctx, "/instance/zone") - if err != nil { - return nil, gerrors.Wrap(err) - } - // Parse `projects//zones/` - parts := strings.Split(projectZone, "/") - return &GCPBackend{ - instanceName: instanceName, - project: parts[1], - zone: parts[3], - }, nil -} - -func (b *GCPBackend) Terminate(ctx context.Context) error { - instancesClient, err := compute.NewInstancesRESTClient(ctx) - if err != nil { - return nil - } - req := &computepb.DeleteInstanceRequest{ - Instance: b.instanceName, - Project: b.project, - Zone: b.zone, - } - _, err = instancesClient.Delete(ctx, req) - return gerrors.Wrap(err) -} - -func getGCPMetadata(ctx context.Context, path string) (string, error) { - req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://metadata.google.internal/computeMetadata/v1%s", path), nil) - if err != nil { - return "", gerrors.Wrap(err) - } - req.Header.Add("Metadata-Flavor", "Google") - res, err := 
http.DefaultClient.Do(req.WithContext(ctx)) - if err != nil { - return "", gerrors.Wrap(err) - } - body, err := io.ReadAll(res.Body) - if err != nil { - return "", gerrors.Wrap(err) - } - return string(body), nil -} diff --git a/runner/internal/shim/backends/lambda.go b/runner/internal/shim/backends/lambda.go deleted file mode 100644 index d2a6325e4..000000000 --- a/runner/internal/shim/backends/lambda.go +++ /dev/null @@ -1,82 +0,0 @@ -package backends - -import ( - "bytes" - "context" - "encoding/json" - "net/http" - "os" - - "github.com/dstackai/dstack/runner/internal/gerrors" -) - -const LAMBDA_API_URL = "https://cloud.lambdalabs.com/api/v1" - -type LambdaAPIClient struct { - apiKey string -} - -type TerminateInstanceRequest struct { - InstanceIDs []string `json:"instance_ids"` -} - -func NewLambdaAPIClient(apiKey string) *LambdaAPIClient { - return &LambdaAPIClient{apiKey: apiKey} -} - -func (client *LambdaAPIClient) TerminateInstance(ctx context.Context, instanceIDs []string) error { - body, err := json.Marshal(TerminateInstanceRequest{InstanceIDs: instanceIDs}) - if err != nil { - return gerrors.Wrap(err) - } - req, err := http.NewRequest("POST", LAMBDA_API_URL+"/instance-operations/terminate", bytes.NewReader(body)) - if err != nil { - return gerrors.Wrap(err) - } - req.Header.Add("Authorization", "Bearer "+client.apiKey) - httpClient := http.Client{} - resp, err := httpClient.Do(req) - if err != nil { - return gerrors.Wrap(err) - } - if resp.StatusCode == 200 { - return nil - } - return gerrors.Newf("/instance-operations/terminate returned non-200 status code: %s", resp.Status) -} - -const LAMBDA_CONFIG_PATH = "/home/ubuntu/.dstack/config.json" - -type LambdaConfig struct { - InstanceID string `json:"instance_id"` - ApiKey string `json:"api_key"` -} - -type LambdaBackend struct { - apiClient *LambdaAPIClient - config LambdaConfig -} - -func init() { - register("lambda", NewLambdaBackend) -} - -func NewLambdaBackend(ctx context.Context) (Backend, error) { - 
config := LambdaConfig{} - fileContent, err := os.ReadFile(LAMBDA_CONFIG_PATH) - if err != nil { - return nil, gerrors.Wrap(err) - } - err = json.Unmarshal(fileContent, &config) - if err != nil { - return nil, gerrors.Wrap(err) - } - return &LambdaBackend{ - apiClient: NewLambdaAPIClient(config.ApiKey), - config: config, - }, nil -} - -func (b *LambdaBackend) Terminate(ctx context.Context) error { - return gerrors.Wrap(b.apiClient.TerminateInstance(ctx, []string{b.config.InstanceID})) -} diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index f3b6a4822..b70170cb3 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -19,88 +19,137 @@ import ( docker "github.com/docker/docker/client" "github.com/docker/go-connections/nat" "github.com/dstackai/dstack/runner/consts" - "github.com/dstackai/dstack/runner/internal/gerrors" + "github.com/ztrue/tracerr" ) -func RunDocker(ctx context.Context, params DockerParameters, serverAPI APIAdapter) error { +type DockerRunner struct { + client *docker.Client + dockerParams DockerParameters + state RunnerStatus +} + +func NewDockerRunner(dockerParams DockerParameters) (*DockerRunner, error) { client, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithAPIVersionNegotiation()) if err != nil { - return err + return nil, tracerr.Wrap(err) } - log.Println("Waiting for registry auth") - registryAuth := <-serverAPI.GetRegistryAuth() - serverAPI.SetState(Pulling) + runner := &DockerRunner{ + client: client, + dockerParams: dockerParams, + state: Pending, + } + return runner, nil +} + +func (d *DockerRunner) Run(ctx context.Context, cfg DockerImageConfig) error { + var err error log.Println("Pulling image") - if err = pullImage(ctx, client, params.DockerImageName(), registryAuth); err != nil { - return gerrors.Wrap(err) + d.state = Pulling + if err = pullImage(ctx, d.client, cfg); err != nil { + d.state = Pending + fmt.Printf("pullImage error: %s\n", err.Error()) + return err } + 
log.Println("Creating container") - containerID, err := createContainer(ctx, client, params) + d.state = Creating + containerID, err := createContainer(ctx, d.client, d.dockerParams, cfg) if err != nil { - return gerrors.Wrap(err) + d.state = Pending + fmt.Printf("createContainer error: %s\n", err.Error()) + return err } - if !params.DockerKeepContainer() { + + if !d.dockerParams.DockerKeepContainer() { defer func() { log.Println("Deleting container") - _ = client.ContainerRemove(ctx, containerID, types.ContainerRemoveOptions{Force: true}) + err := d.client.ContainerRemove(ctx, containerID, types.ContainerRemoveOptions{Force: true}) + if err != nil { + log.Printf("ContainerRemove error: %s\n", err.Error()) + } }() } - serverAPI.SetState(Running) log.Printf("Running container, id=%s\n", containerID) - if err = runContainer(ctx, client, containerID); err != nil { - return gerrors.Wrap(err) + d.state = Running + if err = runContainer(ctx, d.client, containerID); err != nil { + d.state = Pending + fmt.Printf("runContainer error: %s\n", err.Error()) + return err } - log.Println("Container finished successfully") + + log.Printf("Container finished successfully, id=%s\n", containerID) + + d.state = Pending return nil } -func pullImage(ctx context.Context, client docker.APIClient, imageName string, registryAuth string) error { - if !strings.Contains(imageName, ":") { - imageName += ":latest" +func (d DockerRunner) GetState() RunnerStatus { + return d.state +} + +func pullImage(ctx context.Context, client docker.APIClient, taskParams DockerImageConfig) error { + if !strings.Contains(taskParams.ImageName, ":") { + taskParams.ImageName += ":latest" } images, err := client.ImageList(ctx, types.ImageListOptions{ - Filters: filters.NewArgs(filters.Arg("reference", imageName)), + Filters: filters.NewArgs(filters.Arg("reference", taskParams.ImageName)), }) if err != nil { - return gerrors.Wrap(err) + return tracerr.Wrap(err) } - if len(images) > 0 { + + // TODO: force pull latset 
+ if len(images) > 0 && !strings.Contains(taskParams.ImageName, ":latest") { return nil } - reader, err := client.ImagePull(ctx, imageName, types.ImagePullOptions{RegistryAuth: registryAuth}) // todo test registry auth + opts := types.ImagePullOptions{} + regAuth, _ := taskParams.EncodeRegistryAuth() + if regAuth != "" { + opts.RegistryAuth = regAuth + } + + reader, err := client.ImagePull(ctx, taskParams.ImageName, opts) // todo test registry auth if err != nil { - return gerrors.Wrap(err) + return tracerr.Wrap(err) } defer func() { _ = reader.Close() }() - _, err = io.ReadAll(reader) - return gerrors.Wrap(err) + _, err = io.Copy(io.Discard, reader) + if err != nil { + return tracerr.Wrap(err) + } + + // {"status":"Pulling from clickhouse/clickhouse-server","id":"latest"} + // {"status":"Digest: sha256:2ff5796c67e8d588273a5f3f84184b9cdaa39a324bcf74abd3652d818d755f8c"} + // {"status":"Status: Downloaded newer image for clickhouse/clickhouse-server:latest"} + + return nil } -func createContainer(ctx context.Context, client docker.APIClient, params DockerParameters) (string, error) { +func createContainer(ctx context.Context, client docker.APIClient, dockerParams DockerParameters, taskParams DockerImageConfig) (string, error) { runtime, err := getRuntime(ctx, client) if err != nil { - return "", gerrors.Wrap(err) + return "", tracerr.Wrap(err) } - mounts, err := params.DockerMounts() + mounts, err := dockerParams.DockerMounts() if err != nil { - return "", gerrors.Wrap(err) + return "", tracerr.Wrap(err) } containerConfig := &container.Config{ - Image: params.DockerImageName(), - Cmd: []string{strings.Join(params.DockerShellCommands(), " && ")}, + Image: taskParams.ImageName, + Cmd: []string{strings.Join(dockerParams.DockerShellCommands(), " && ")}, Entrypoint: []string{"/bin/sh", "-c"}, - ExposedPorts: exposePorts(params.DockerPorts()...), + ExposedPorts: exposePorts(dockerParams.DockerPorts()...), } hostConfig := &container.HostConfig{ NetworkMode: 
getNetworkMode(), - PortBindings: bindPorts(params.DockerPorts()...), + PortBindings: bindPorts(dockerParams.DockerPorts()...), PublishAllPorts: true, Sysctls: map[string]string{}, Runtime: runtime, @@ -108,20 +157,20 @@ func createContainer(ctx context.Context, client docker.APIClient, params Docker } resp, err := client.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, "") if err != nil { - return "", gerrors.Wrap(err) + return "", tracerr.Wrap(err) } return resp.ID, nil } func runContainer(ctx context.Context, client docker.APIClient, containerID string) error { if err := client.ContainerStart(ctx, containerID, types.ContainerStartOptions{}); err != nil { - return gerrors.Wrap(err) + return tracerr.Wrap(err) } waitCh, errorCh := client.ContainerWait(ctx, containerID, "") select { case <-waitCh: case err := <-errorCh: - return gerrors.Wrap(err) + return tracerr.Wrap(err) } return nil } @@ -181,7 +230,7 @@ func getNetworkMode() container.NetworkMode { func getRuntime(ctx context.Context, client docker.APIClient) (string, error) { info, err := client.Info(ctx) if err != nil { - return "", gerrors.Wrap(err) + return "", tracerr.Wrap(err) } for name := range info.Runtimes { if name == consts.NVIDIA_RUNTIME { @@ -193,24 +242,20 @@ func getRuntime(ctx context.Context, client docker.APIClient) (string, error) { /* DockerParameters interface implementation for CLIArgs */ -func (c *CLIArgs) DockerImageName() string { - return c.Docker.ImageName -} - -func (c *CLIArgs) DockerKeepContainer() bool { +func (c CLIArgs) DockerKeepContainer() bool { return c.Docker.KeepContainer } -func (c *CLIArgs) DockerShellCommands() []string { +func (c CLIArgs) DockerShellCommands() []string { commands := getSSHShellCommands(c.Docker.SSHPort, c.Docker.PublicSSHKey) commands = append(commands, fmt.Sprintf("%s %s", DstackRunnerBinaryName, strings.Join(c.getRunnerArgs(), " "))) return commands } -func (c *CLIArgs) DockerMounts() ([]mount.Mount, error) { +func (c CLIArgs) 
DockerMounts() ([]mount.Mount, error) { runnerTemp := filepath.Join(c.Shim.HomeDir, "runners", time.Now().Format("20060102-150405")) if err := os.MkdirAll(runnerTemp, 0755); err != nil { - return nil, gerrors.Wrap(err) + return nil, tracerr.Wrap(err) } return []mount.Mount{ @@ -227,6 +272,6 @@ func (c *CLIArgs) DockerMounts() ([]mount.Mount, error) { }, nil } -func (c *CLIArgs) DockerPorts() []int { +func (c CLIArgs) DockerPorts() []int { return []int{c.Runner.HTTPPort, c.Docker.SSHPort} } diff --git a/runner/internal/shim/docker_test.go b/runner/internal/shim/docker_test.go index bf3cee691..60f29a3c9 100644 --- a/runner/internal/shim/docker_test.go +++ b/runner/internal/shim/docker_test.go @@ -31,7 +31,8 @@ func TestDocker_SSHServer(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) defer cancel() - assert.NoError(t, RunDocker(ctx, params, &apiAdapterMock{})) + dockerRunner, _ := NewDockerRunner(params) + assert.NoError(t, dockerRunner.Run(ctx, DockerImageConfig{ImageName: "ubuntu"})) } // TestDocker_SSHServerConnect pulls ubuntu image (without sshd), installs openssh-server and tries to connect via SSH @@ -56,11 +57,13 @@ func TestDocker_SSHServerConnect(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) defer cancel() + dockerRunner, _ := NewDockerRunner(params) + var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() - assert.NoError(t, RunDocker(ctx, params, &apiAdapterMock{})) + assert.NoError(t, dockerRunner.Run(ctx, DockerImageConfig{ImageName: "ubuntu"})) }() for i := 0; i < timeout; i++ { @@ -89,10 +92,6 @@ type dockerParametersMock struct { publicSSHKey string } -func (c *dockerParametersMock) DockerImageName() string { - return "ubuntu" -} - func (c *dockerParametersMock) DockerKeepContainer() bool { return false } @@ -114,16 +113,6 @@ func (c *dockerParametersMock) DockerMounts() ([]mount.Mount, error) { return nil, nil } -type 
apiAdapterMock struct{} - -func (s *apiAdapterMock) GetRegistryAuth() <-chan string { - ch := make(chan string) - close(ch) - return ch -} - -func (s *apiAdapterMock) SetState(string) {} - /* Utilities */ var portNumber int32 = 10000 diff --git a/runner/internal/shim/models.go b/runner/internal/shim/models.go index 0a88847c7..913587e52 100644 --- a/runner/internal/shim/models.go +++ b/runner/internal/shim/models.go @@ -1,16 +1,15 @@ package shim import ( + "encoding/base64" + "encoding/json" + "log" + "github.com/docker/docker/api/types/mount" + "github.com/docker/docker/api/types/registry" ) -type APIAdapter interface { - GetRegistryAuth() <-chan string - SetState(string) -} - type DockerParameters interface { - DockerImageName() string DockerKeepContainer() bool DockerShellCommands() []string DockerMounts() ([]mount.Mount, error) @@ -35,10 +34,33 @@ type CLIArgs struct { } Docker struct { - SSHPort int - RegistryAuthRequired bool - ImageName string - KeepContainer bool - PublicSSHKey string + SSHPort int + KeepContainer bool + PublicSSHKey string } } + +type DockerImageConfig struct { + Username string + Password string + ImageName string +} + +func (ra DockerImageConfig) EncodeRegistryAuth() (string, error) { + if ra.Username == "" && ra.Password == "" { + return "", nil + } + + authConfig := registry.AuthConfig{ + Username: ra.Username, + Password: ra.Password, + } + + encodedConfig, err := json.Marshal(authConfig) + if err != nil { + log.Println("Failed to encode auth config", "err", err) + return "", err + } + + return base64.URLEncoding.EncodeToString(encodedConfig), nil +} diff --git a/runner/internal/shim/runner.go b/runner/internal/shim/runner.go index 3cbf1ce64..1388c4b85 100644 --- a/runner/internal/shim/runner.go +++ b/runner/internal/shim/runner.go @@ -1,14 +1,15 @@ package shim import ( + "context" "fmt" "io" "log" "net/http" "os" - rt "runtime" "strconv" "strings" + "time" "github.com/dstackai/dstack/runner/internal/gerrors" ) @@ -27,16 +28,17 @@ 
func (c *CLIArgs) GetDockerCommands() []string { } } -func (c *CLIArgs) Download(osName string) error { - tempFile, err := os.CreateTemp("", "dstack-runner") +func (c *CLIArgs) DownloadRunner() error { + url := makeDownloadRunnerUrl(c.Runner.Version, c.Runner.DevChannel) + + runnerBinaryPath, err := downloadRunner(url) if err != nil { return gerrors.Wrap(err) } - if err = tempFile.Close(); err != nil { - return gerrors.Wrap(err) - } - c.Runner.BinaryPath = tempFile.Name() - return gerrors.Wrap(downloadRunner(c.Runner.Version, c.Runner.DevChannel, osName, c.Runner.BinaryPath)) + + c.Runner.BinaryPath = runnerBinaryPath + + return nil } func (c *CLIArgs) getRunnerArgs() []string { @@ -50,42 +52,62 @@ func (c *CLIArgs) getRunnerArgs() []string { } } -func downloadRunner(runnerVersion string, useDev bool, osName string, path string) error { - // darwin-amd64 - // darwin-arm64 - // linux-386 - // linux-amd64 - archName := rt.GOARCH - if osName == "linux" && archName == "arm64" { - archName = "amd64" - } - var url string - if useDev { - url = fmt.Sprintf(DstackRunnerURL, DstackStagingBucket, runnerVersion, osName, archName) - } else { - url = fmt.Sprintf(DstackRunnerURL, DstackReleaseBucket, runnerVersion, osName, archName) +func makeDownloadRunnerUrl(version string, staging bool) string { + bucket := DstackReleaseBucket + if staging { + bucket = DstackStagingBucket } - file, err := os.Create(path) + osName := "linux" + archName := "amd64" + + url := fmt.Sprintf(DstackRunnerURL, bucket, version, osName, archName) + return url +} + +func downloadRunner(url string) (string, error) { + tempFile, err := os.CreateTemp("", "dstack-runner") if err != nil { - return gerrors.Wrap(err) + return "", gerrors.Wrap(err) } - defer func() { _ = file.Close() }() + defer func() { + err := tempFile.Close() + if err != nil { + log.Printf("close file error: %s\n", err) + } + }() log.Printf("Downloading runner from %s\n", url) - resp, err := http.Get(url) + ctx, cancel := 
context.WithTimeout(context.Background(), time.Second*600) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { - return gerrors.Wrap(err) + return "", gerrors.Wrap(err) } - defer func() { _ = resp.Body.Close() }() + + resp, err := http.DefaultClient.Do(req) + + if err != nil { + return "", gerrors.Wrap(err) + } + defer func() { + err := resp.Body.Close() + log.Printf("close body error: %s\n", err) + }() + if resp.StatusCode != http.StatusOK { - return gerrors.Newf("unexpected status code: %s", resp.Status) + return "", gerrors.Newf("unexpected status code: %s", resp.Status) } - _, err = io.Copy(file, resp.Body) + _, err = io.Copy(tempFile, resp.Body) if err != nil { - return gerrors.Wrap(err) + return "", gerrors.Wrap(err) + } + + if err := tempFile.Chmod(0755); err != nil { + return "", gerrors.Wrap(err) } - return gerrors.Wrap(file.Chmod(0755)) + return tempFile.Name(), nil } diff --git a/runner/internal/shim/states.go b/runner/internal/shim/states.go index eedc05c85..e12f66041 100644 --- a/runner/internal/shim/states.go +++ b/runner/internal/shim/states.go @@ -1,7 +1,10 @@ package shim +type RunnerStatus string + const ( - WaitRegistryAuth = "waiting_for_registry_auth" - Pulling = "pulling" - Running = "running" + Pending RunnerStatus = "pending" + Pulling RunnerStatus = "pulling" + Creating RunnerStatus = "creating" + Running RunnerStatus = "running" ) diff --git a/runner/internal/shim/subprocess.go b/runner/internal/shim/subprocess.go deleted file mode 100644 index 38268d371..000000000 --- a/runner/internal/shim/subprocess.go +++ /dev/null @@ -1,28 +0,0 @@ -package shim - -import ( - "github.com/dstackai/dstack/runner/internal/gerrors" - "os" - "path/filepath" - rt "runtime" -) - -func RunSubprocess(httpPort int, logLevel int, runnerVersion string, useDev bool) error { - userHomeDir, err := os.UserHomeDir() - if err != nil { - return gerrors.Wrap(err) - } - runnerPath := filepath.Join(userHomeDir, 
".dstack/dstack-runner") - if err = os.MkdirAll(filepath.Dir(runnerPath), 0755); err != nil { - return gerrors.Wrap(err) - } - - err = downloadRunner(runnerVersion, useDev, rt.GOOS, runnerPath) - if err != nil { - return gerrors.Wrap(err) - } - // todo create temporary, home and working dirs - // todo start runner - // todo wait till runner completes - return nil -} diff --git a/src/dstack/_internal/cli/commands/config.py b/src/dstack/_internal/cli/commands/config.py index c251baa54..6c98b2107 100644 --- a/src/dstack/_internal/cli/commands/config.py +++ b/src/dstack/_internal/cli/commands/config.py @@ -4,7 +4,7 @@ import dstack.api.server from dstack._internal.cli.commands import BaseCommand -from dstack._internal.cli.utils.common import colors, confirm_ask, console +from dstack._internal.cli.utils.common import confirm_ask, console from dstack._internal.core.errors import CLIError from dstack._internal.core.services.configs import ConfigManager @@ -79,6 +79,4 @@ def _command(self, args: argparse.Namespace): name=args.project, url=args.url, token=args.token, default=set_it_as_default ) config_manager.save() - console.print( - f"Configuration updated at [{colors['code']}]{config_manager.config_filepath}[/{colors['code']}]" - ) + console.print(f"Configuration updated at [code]{config_manager.config_filepath}[/]") diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py new file mode 100644 index 000000000..b815a8a37 --- /dev/null +++ b/src/dstack/_internal/cli/commands/pool.py @@ -0,0 +1,443 @@ +import argparse +from pathlib import Path +from typing import Sequence + +from rich.table import Table + +from dstack._internal.cli.commands import APIBaseCommand +from dstack._internal.cli.services.args import cpu_spec, disk_spec, gpu_spec, memory_spec +from dstack._internal.cli.services.configurators.profile import ( + apply_profile_args, + register_profile_args, +) +from dstack._internal.cli.utils.common import confirm_ask, console 
+from dstack._internal.core.errors import CLIError, ServerClientError +from dstack._internal.core.models.instances import ( + InstanceAvailability, + InstanceOfferWithAvailability, + SSHKey, +) +from dstack._internal.core.models.pools import Instance, Pool +from dstack._internal.core.models.profiles import ( + DEFAULT_TERMINATION_IDLE_TIME, + Profile, + SpotPolicy, + TerminationPolicy, +) +from dstack._internal.core.models.resources import DEFAULT_CPU_COUNT, DEFAULT_MEMORY_SIZE +from dstack._internal.core.models.runs import InstanceStatus, Requirements +from dstack._internal.utils.common import pretty_date +from dstack._internal.utils.logging import get_logger +from dstack.api._public.resources import Resources +from dstack.api.utils import load_profile + +logger = get_logger(__name__) + + +class PoolCommand(APIBaseCommand): + NAME = "pool" + DESCRIPTION = "Pool management" + + def _register(self) -> None: + super()._register() + self._parser.set_defaults(subfunc=self._list) + + subparsers = self._parser.add_subparsers(dest="action") + + # list pools + list_parser = subparsers.add_parser( + "list", + help="List pools", + description="List available pools", + formatter_class=self._parser.formatter_class, + ) + list_parser.add_argument("-v", "--verbose", help="Show more information") + list_parser.set_defaults(subfunc=self._list) + + # create pool + create_parser = subparsers.add_parser( + "create", help="Create pool", formatter_class=self._parser.formatter_class + ) + create_parser.add_argument( + "-n", "--name", dest="pool_name", help="The name of the pool", required=True + ) + create_parser.set_defaults(subfunc=self._create) + + # delete pool + delete_parser = subparsers.add_parser( + "delete", help="Delete pool", formatter_class=self._parser.formatter_class + ) + delete_parser.add_argument( + "-n", "--name", dest="pool_name", help="The name of the pool", required=True + ) + delete_parser.add_argument( + "-f", "--force", dest="force", help="Force remove", 
type=bool, default=False + ) + delete_parser.set_defaults(subfunc=self._delete) + + # show pool instances + show_parser = subparsers.add_parser( + "show", + help="Show pool instances", + description="Show instances in the pool", + formatter_class=self._parser.formatter_class, + ) + show_parser.add_argument( + "--pool", + dest="pool_name", + help="The name of the pool. If not set, the default pool will be used", + ) + show_parser.set_defaults(subfunc=self._show) + + # add instance + add_parser = subparsers.add_parser( + "add", help="Add instance to pool", formatter_class=self._parser.formatter_class + ) + add_parser.add_argument( + "--pool", + dest="pool_name", + help="The name of the pool. If not set, the default pool will be used", + ) + add_parser.add_argument( + "-y", "--yes", help="Don't ask for confirmation", action="store_true" + ) + add_parser.add_argument( + "--remote", + help="Add remote runner as an instance", + dest="remote", + action="store_true", + default=False, + ) + add_parser.add_argument("--remote-host", help="Remote runner host", dest="remote_host") + add_parser.add_argument( + "--remote-port", help="Remote runner port", dest="remote_port", default=10999 + ) + add_parser.add_argument("--name", dest="instance_name", help="The name of the instance") + register_profile_args(add_parser) + register_resource_args(add_parser) + add_parser.set_defaults(subfunc=self._add) + + # remove instance + remove_parser = subparsers.add_parser( + "remove", + help="Remove instance from the pool", + formatter_class=self._parser.formatter_class, + ) + remove_parser.add_argument( + "instance_name", + help="The name of the instance", + ) + remove_parser.add_argument( + "--pool", + dest="pool_name", + help="The name of the pool. 
If not set, the default pool will be used", + ) + remove_parser.add_argument( + "--force", + action="store_true", + help="The name of the instance", + ) + remove_parser.add_argument( + "-y", "--yes", help="Don't ask for confirmation", action="store_true" + ) + remove_parser.set_defaults(subfunc=self._remove) + + # pool set-default + set_default_parser = subparsers.add_parser( + "set-default", + help="Set the project's default pool", + formatter_class=self._parser.formatter_class, + ) + set_default_parser.add_argument( + "--pool", dest="pool_name", help="The name of the pool", required=True + ) + set_default_parser.set_defaults(subfunc=self._set_default) + + def _list(self, args: argparse.Namespace) -> None: + pools = self.api.client.pool.list(self.api.project) + print_pool_table(pools, verbose=getattr(args, "verbose", False)) + + def _create(self, args: argparse.Namespace) -> None: + self.api.client.pool.create(self.api.project, args.pool_name) + console.print(f"Pool {args.pool_name!r} created") + + def _delete(self, args: argparse.Namespace) -> None: + # TODO(egor-s): ask for confirmation + with console.status("Removing pool..."): + self.api.client.pool.delete(self.api.project, args.pool_name, args.force) + console.print(f"Pool {args.pool_name!r} removed") + + def _remove(self, args: argparse.Namespace) -> None: + pool = self.api.client.pool.show(self.api.project, args.pool_name) + pool.instances = [i for i in pool.instances if i.name == args.instance_name] + if not pool.instances: + raise CLIError(f"Instance {args.instance_name!r} not found in pool {pool.name!r}") + + console.print(f" [bold]Pool name[/] {pool.name}\n") + print_instance_table(pool.instances) + + if not args.force and any(i.status == InstanceStatus.BUSY for i in pool.instances): + # TODO(egor-s): implement this logic in the server too + raise CLIError("Can't remove busy instance. 
Use `--force` to remove anyway") + + if not args.yes and not confirm_ask(f"Remove instance {args.instance_name!r}?"): + console.print("\nExiting...") + return + + with console.status("Removing instance..."): + self.api.client.pool.remove( + self.api.project, pool.name, args.instance_name, args.force + ) + console.print(f"Instance {args.instance_name!r} removed") + + def _set_default(self, args: argparse.Namespace) -> None: + result = self.api.client.pool.set_default(self.api.project, args.pool_name) + if not result: + console.print(f"Failed to set default pool {args.pool_name!r}", style="error") + + def _show(self, args: argparse.Namespace) -> None: + resp = self.api.client.pool.show(self.api.project, args.pool_name) + console.print(f" [bold]Pool name[/] {resp.name}\n") + print_instance_table(resp.instances) + + def _add(self, args: argparse.Namespace) -> None: + super()._command(args) + + resources = Resources( + cpu=args.cpu, + memory=args.memory, + gpu=args.gpu, + shm_size=args.shared_memory, + disk=args.disk, + ) + requirements = Requirements( + resources=resources, + max_price=args.max_price, + spot=(args.spot_policy == SpotPolicy.SPOT), # TODO(egor-s): None if SpotPolicy.AUTO + ) + + profile = load_profile(Path.cwd(), args.profile) + apply_profile_args(args, profile) + profile.pool_name = args.pool_name + + termination_policy_idle = DEFAULT_TERMINATION_IDLE_TIME + termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE + profile.termination_idle_time = termination_policy_idle + profile.termination_policy = termination_policy + + # Add remote instance + if args.remote: + result = self.api.client.pool.add_remote( + self.api.project, + resources, + profile, + args.instance_name, + args.remote_host, + args.remote_port, + ) + if not result: + console.print(f"[error]Failed to add remote instance {args.instance_name!r}[/]") + # TODO(egor-s): print on success + return + + with console.status("Getting instances..."): + pool_name, offers = 
self.api.runs.get_offers(profile, requirements) + + print_offers_table(pool_name, profile, requirements, offers) + if not args.yes and not confirm_ask("Continue?"): + console.print("\nExiting...") + return + + # TODO(egor-s): user pub key must be added during the `run`, not `pool add` + user_pub_key = Path("~/.dstack/ssh/id_rsa.pub").expanduser().read_text().strip() + pub_key = SSHKey(public=user_pub_key) + try: + with console.status("Creating instance..."): + instance = self.api.runs.create_instance(pool_name, profile, requirements, pub_key) + except ServerClientError as e: + raise CLIError(e.msg) + print_instance_table([instance]) + + def _command(self, args: argparse.Namespace) -> None: + super()._command(args) + # TODO handle 404 and other errors + args.subfunc(args) + + +def print_pool_table(pools: Sequence[Pool], verbose: bool) -> None: + table = Table(box=None) + table.add_column("NAME") + table.add_column("DEFAULT") + table.add_column("INSTANCES") + if verbose: + table.add_column("CREATED") + + sorted_pools = sorted(pools, key=lambda r: r.name) + for pool in sorted_pools: + default_mark = "default" if pool.default else "" + style = "success" if pool.total_instances == pool.available_instances else "error" + health = f"[{style}]{pool.available_instances}/{pool.total_instances}[/]" + row = [pool.name, default_mark, health] + if verbose: + row.append(pretty_date(pool.created_at)) + table.add_row(*row) + + console.print(table) + console.print() + + +def print_instance_table(instances: Sequence[Instance]) -> None: + table = Table(box=None) + table.add_column("INSTANCE NAME") + table.add_column("BACKEND") + table.add_column("INSTANCE TYPE") + table.add_column("STATUS") + table.add_column("PRICE") + + for instance in instances: + style = "success" if instance.status.is_available() else "warning" + row = [ + instance.name, + instance.backend, + instance.instance_type.resources.pretty_format(), + f"[{style}]{instance.status.value}[/]", + f"${instance.price:.4}", + ] 
+ table.add_row(*row) + + console.print(table) + console.print() + + +def print_offers_table( + pool_name: str, + profile: Profile, + requirements: Requirements, + instance_offers: Sequence[InstanceOfferWithAvailability], + offers_limit: int = 3, +) -> None: + pretty_req = requirements.pretty_format(resources_only=True) + max_price = f"${requirements.max_price:g}" if requirements.max_price else "-" + max_duration = ( + f"{profile.max_duration / 3600:g}h" if isinstance(profile.max_duration, int) else "-" + ) + + # TODO: improve retry policy + # retry_policy = profile.retry_policy + # retry_policy = ( + # (f"{retry_policy.limit / 3600:g}h" if retry_policy.limit else "yes") + # if retry_policy.retry + # else "no" + # ) + + # TODO: improve spot policy + if requirements.spot is None: + spot_policy = "auto" + elif requirements.spot: + spot_policy = "spot" + else: + spot_policy = "on-demand" + + def th(s: str) -> str: + return f"[bold]{s}[/bold]" + + props = Table(box=None, show_header=False) + props.add_column(no_wrap=True) # key + props.add_column() # value + + props.add_row(th("Pool name"), pool_name) + props.add_row(th("Min resources"), pretty_req) + props.add_row(th("Max price"), max_price) + props.add_row(th("Max duration"), max_duration) + props.add_row(th("Spot policy"), spot_policy) + # props.add_row(th("Retry policy"), retry_policy) + + offers_table = Table(box=None) + offers_table.add_column("#") + offers_table.add_column("BACKEND") + offers_table.add_column("REGION") + offers_table.add_column("INSTANCE") + offers_table.add_column("RESOURCES") + offers_table.add_column("SPOT") + offers_table.add_column("PRICE") + offers_table.add_column() + + print_offers = instance_offers[:offers_limit] + + for i, offer in enumerate(print_offers, start=1): + r = offer.instance.resources + + availability = "" + if offer.availability in { + InstanceAvailability.NOT_AVAILABLE, + InstanceAvailability.NO_QUOTA, + }: + availability = offer.availability.value.replace("_", " 
").title() + offers_table.add_row( + f"{i}", + offer.backend, + offer.region, + offer.instance.name, + r.pretty_format(), + "yes" if r.spot else "no", + f"${offer.price:g}", + availability, + style=None if i == 1 else "secondary", + ) + if len(print_offers) > offers_limit: + offers_table.add_row("", "...", style="secondary") + + console.print(props) + console.print() + if len(print_offers) > 0: + console.print(offers_table) + console.print() + + +def register_resource_args(parser: argparse.ArgumentParser) -> None: + resources_group = parser.add_argument_group("Resources") + resources_group.add_argument( + "--cpu", + help=f"Request the CPU count. Default: {DEFAULT_CPU_COUNT}", + dest="cpu", + metavar="SPEC", + default=DEFAULT_CPU_COUNT, + type=cpu_spec, + ) + + resources_group.add_argument( + "--memory", + help="Request the size of RAM. " + f"The format is [code]SIZE[/]:[code]MB|GB|TB[/]. Default: {DEFAULT_MEMORY_SIZE}", + dest="memory", + metavar="SIZE", + default=DEFAULT_MEMORY_SIZE, + type=memory_spec, + ) + + resources_group.add_argument( + "--shared-memory", + help="Request the size of Shared Memory. The format is [code]SIZE[/]:[code]MB|GB|TB[/].", + dest="shared_memory", + default=None, + metavar="SIZE", + ) + + resources_group.add_argument( + "--gpu", + help="Request GPU for the run. " + "The format is [code]NAME[/]:[code]COUNT[/]:[code]MEMORY[/] (all parts are optional)", + dest="gpu", + default=None, + metavar="SPEC", + type=gpu_spec, + ) + + resources_group.add_argument( + "--disk", + help="Request the size of disk for the run. 
Example [code]--disk 100GB..[/].", + dest="disk", + metavar="SIZE", + default=None, + type=disk_spec, + ) diff --git a/src/dstack/_internal/cli/commands/run.py b/src/dstack/_internal/cli/commands/run.py index f39296da4..4ee276fc1 100644 --- a/src/dstack/_internal/cli/commands/run.py +++ b/src/dstack/_internal/cli/commands/run.py @@ -17,8 +17,14 @@ from dstack._internal.cli.utils.run import print_run_plan from dstack._internal.core.errors import CLIError, ConfigurationError, ServerClientError from dstack._internal.core.models.configurations import ConfigurationType +from dstack._internal.core.models.profiles import ( + DEFAULT_TERMINATION_IDLE_TIME, + CreationPolicy, + TerminationPolicy, +) from dstack._internal.core.models.runs import JobErrorCode from dstack._internal.core.services.configs import ConfigManager +from dstack._internal.utils.common import parse_pretty_duration from dstack._internal.utils.logging import get_logger from dstack.api import RunStatus from dstack.api._public.runs import Run @@ -78,6 +84,29 @@ def _register(self): type=int, default=3, ) + self._parser.add_argument( + "--pool", + dest="pool_name", + help="The name of the pool. 
If not set, the default pool will be used", + ) + self._parser.add_argument( + "--reuse", + dest="creation_policy_reuse", + action="store_true", + help="Reuse instance from pool", + ) + self._parser.add_argument( + "--idle-duration", + dest="idle_duration", + type=str, + help="Idle time before instance termination", + ) + self._parser.add_argument( + "--instance", + dest="instance_name", + metavar="NAME", + help="Reuse instance from pool with name [code]NAME[/]", + ) register_profile_args(self._parser) def _command(self, args: argparse.Namespace): @@ -89,6 +118,31 @@ def _command(self, args: argparse.Namespace): self._parser.print_help() return + termination_policy_idle = DEFAULT_TERMINATION_IDLE_TIME + termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE + + if args.idle_duration is not None: + try: + termination_policy_idle = int(args.idle_duration) + except ValueError: + termination_policy_idle = 60 * parse_pretty_duration(args.idle_duration) + + creation_policy = ( + CreationPolicy.REUSE if args.creation_policy_reuse else CreationPolicy.REUSE_OR_CREATE + ) + + if creation_policy == CreationPolicy.REUSE and termination_policy_idle is not None: + console.print( + "[warning]If the flag --reuse is set, the argument --idle-duration will be skipped[/]" + ) + termination_policy = TerminationPolicy.DONT_DESTROY + + if args.instance_name is not None and termination_policy_idle is not None: + console.print( + f"[warning]--idle-duration won't be applied to the instance {args.instance_name!r}[/]" + ) + termination_policy = TerminationPolicy.DONT_DESTROY + super()._command(args) try: repo = self.api.repos.load(Path.cwd()) @@ -121,6 +175,11 @@ def _command(self, args: argparse.Namespace): max_price=profile.max_price, working_dir=args.working_dir, run_name=args.run_name, + pool_name=args.pool_name, + instance_name=args.instance_name, + creation_policy=creation_policy, + termination_policy=termination_policy, + termination_policy_idle=termination_policy_idle, ) except 
ConfigurationError as e: raise CLIError(str(e)) @@ -179,10 +238,16 @@ def _command(self, args: argparse.Namespace): else: console.print("[error]Failed to attach, exiting...[/]") - run.refresh() - if run.status.is_finished(): - _print_fail_message(run) - abort_at_exit = False + # After reading the logs, the run may not be marked as finished immediately. + # Give the run some time to transit into a finished state before aborting it. + for _ in range(5): + run.refresh() + if run.status.is_finished(): + if run.status == RunStatus.FAILED: + _print_fail_message(run) + abort_at_exit = False + break + time.sleep(1) except KeyboardInterrupt: try: if not confirm_ask("\nStop the run before detaching?"): diff --git a/src/dstack/_internal/cli/main.py b/src/dstack/_internal/cli/main.py index a714d3a75..a6afa33a7 100644 --- a/src/dstack/_internal/cli/main.py +++ b/src/dstack/_internal/cli/main.py @@ -6,11 +6,12 @@ from dstack._internal.cli.commands.gateway import GatewayCommand from dstack._internal.cli.commands.init import InitCommand from dstack._internal.cli.commands.logs import LogsCommand +from dstack._internal.cli.commands.pool import PoolCommand from dstack._internal.cli.commands.ps import PsCommand from dstack._internal.cli.commands.run import RunCommand from dstack._internal.cli.commands.server import ServerCommand from dstack._internal.cli.commands.stop import StopCommand -from dstack._internal.cli.utils.common import colors, console +from dstack._internal.cli.utils.common import _colors, console from dstack._internal.cli.utils.updates import check_for_updates from dstack._internal.core.errors import ClientError, CLIError from dstack._internal.utils.logging import get_logger @@ -21,8 +22,8 @@ def main(): RichHelpFormatter.usage_markup = True - RichHelpFormatter.styles["code"] = colors["code"] - RichHelpFormatter.styles["argparse.args"] = colors["code"] + RichHelpFormatter.styles["code"] = _colors["code"] + RichHelpFormatter.styles["argparse.args"] = _colors["code"] 
RichHelpFormatter.styles["argparse.groups"] = "bold grey74" RichHelpFormatter.styles["argparse.text"] = "grey74" @@ -50,6 +51,7 @@ def main(): subparsers = parser.add_subparsers(metavar="COMMAND") ConfigCommand.register(subparsers) GatewayCommand.register(subparsers) + PoolCommand.register(subparsers) InitCommand.register(subparsers) LogsCommand.register(subparsers) PsCommand.register(subparsers) diff --git a/src/dstack/_internal/cli/services/args.py b/src/dstack/_internal/cli/services/args.py new file mode 100644 index 000000000..24a4663dc --- /dev/null +++ b/src/dstack/_internal/cli/services/args.py @@ -0,0 +1,35 @@ +import re +from typing import Dict, Tuple + +from pydantic import parse_obj_as + +from dstack._internal.core.models import resources as resources +from dstack._internal.core.models.configurations import PortMapping + + +def gpu_spec(v: str) -> Dict: + return resources.GPUSpec.parse(v) + + +def env_var(v: str) -> Tuple[str, str]: + r = re.match(r"^([a-zA-Z_][a-zA-Z0-9_]*)=(.*)$", v) + if r is None: + raise ValueError(v) + key, value = r.groups() + return key, value + + +def port_mapping(v: str) -> PortMapping: + return PortMapping.parse(v) + + +def cpu_spec(v: str) -> resources.Range[int]: + return parse_obj_as(resources.Range[int], v) + + +def memory_spec(v: str) -> resources.Range[resources.Memory]: + return parse_obj_as(resources.Range[resources.Memory], v) + + +def disk_spec(v: str) -> resources.DiskSpec: + return parse_obj_as(resources.DiskSpec, v) diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index 155e85cce..75b26aa44 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -1,11 +1,9 @@ import argparse -import re import subprocess -from typing import Dict, List, Optional, Tuple, Type - -from pydantic import parse_obj_as +from typing import Dict, List, Optional, Type import 
dstack._internal.core.models.resources as resources +from dstack._internal.cli.services.args import disk_spec, env_var, gpu_spec, port_mapping from dstack._internal.cli.utils.common import console from dstack._internal.core.errors import ConfigurationError from dstack._internal.core.models.configurations import ( @@ -131,26 +129,6 @@ def apply(cls, args: argparse.Namespace, unknown: List[str], conf: ServiceConfig cls.interpolate_run_args(conf.commands, unknown) -def env_var(v: str) -> Tuple[str, str]: - r = re.match(r"^([a-zA-Z_][a-zA-Z0-9_]*)=(.*)$", v) - if r is None: - raise ValueError(v) - key, value = r.groups() - return key, value - - -def gpu_spec(v: str) -> Dict: - return resources.GPUSpec.parse(v) - - -def disk_spec(v: str) -> resources.DiskSpec: - return parse_obj_as(resources.DiskSpec, v) - - -def port_mapping(v: str) -> PortMapping: - return PortMapping.parse(v) - - def merge_ports(conf: List[PortMapping], args: List[PortMapping]) -> Dict[int, PortMapping]: unique_ports_constraint([pm.container_port for pm in conf]) unique_ports_constraint([pm.container_port for pm in args]) diff --git a/src/dstack/_internal/cli/utils/common.py b/src/dstack/_internal/cli/utils/common.py index 5bec4dd17..9eb1dba98 100644 --- a/src/dstack/_internal/cli/utils/common.py +++ b/src/dstack/_internal/cli/utils/common.py @@ -11,7 +11,7 @@ from dstack._internal.core.errors import CLIError, DstackError -colors = { +_colors = { "secondary": "grey58", "success": "green", "warning": "yellow", @@ -19,7 +19,7 @@ "code": "bold sea_green3", } -console = Console(theme=Theme(colors)) +console = Console(theme=Theme(_colors)) def cli_error(e: DstackError) -> CLIError: diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index c2509be70..b3c27191e 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -42,6 +42,7 @@ def th(s: str) -> str: props.add_row(th("Configuration"), run_plan.run_spec.configuration_path) 
props.add_row(th("Project"), run_plan.project_name) props.add_row(th("User"), run_plan.user) + props.add_row(th("Pool name"), run_plan.run_spec.profile.pool_name) props.add_row(th("Min resources"), pretty_req) props.add_row(th("Max price"), max_price) props.add_row(th("Max duration"), max_duration) @@ -67,6 +68,8 @@ def th(s: str) -> str: if offer.availability in { InstanceAvailability.NOT_AVAILABLE, InstanceAvailability.NO_QUOTA, + InstanceAvailability.READY, + InstanceAvailability.BUSY, }: availability = offer.availability.value.replace("_", " ").title() offers.add_row( @@ -78,7 +81,7 @@ def th(s: str) -> str: "yes" if r.spot else "no", f"${offer.price:g}", availability, - style=None if i == 1 else "grey58", + style=None if i == 1 else "secondary", ) if job_plan.total_offers > len(job_plan.offers): offers.add_row("", "...", style="secondary") diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index 3e12bb037..121c6d129 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -20,10 +20,12 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( InstanceAvailability, + InstanceConfiguration, InstanceOffer, InstanceOfferWithAvailability, LaunchedGatewayInfo, LaunchedInstanceInfo, + SSHKey, ) from dstack._internal.core.models.runs import Job, Requirements, Run from dstack._internal.utils.logging import get_logger @@ -84,7 +86,7 @@ def get_quotas(client: botocore.client.BaseClient) -> Dict[str, int]: def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None - ): + ) -> None: client = self.session.client("ec2", region_name=region) try: client.terminate_instances(InstanceIds=[instance_id]) @@ -94,24 +96,21 @@ def terminate_instance( else: raise e - def run_job( + def create_instance( self, - run: Run, - job: Job, instance_offer: 
InstanceOfferWithAvailability, - project_ssh_public_key: str, - project_ssh_private_key: str, + instance_config: InstanceConfiguration, ) -> LaunchedInstanceInfo: - project_id = run.project_name + project_name = instance_config.project_name ec2 = self.session.resource("ec2", region_name=instance_offer.region) ec2_client = self.session.client("ec2", region_name=instance_offer.region) iam_client = self.session.client("iam", region_name=instance_offer.region) tags = [ - {"Key": "Name", "Value": get_instance_name(run, job)}, + {"Key": "Name", "Value": instance_config.instance_name}, {"Key": "owner", "Value": "dstack"}, - {"Key": "dstack_project", "Value": project_id}, - {"Key": "dstack_user", "Value": run.user}, + {"Key": "dstack_project", "Value": project_name}, + {"Key": "dstack_user", "Value": instance_config.user}, ] try: subnet_id = None @@ -144,21 +143,13 @@ def run_job( instance_type=instance_offer.instance.name, iam_instance_profile_arn=aws_resources.create_iam_instance_profile( iam_client=iam_client, - project_id=project_id, - ), - user_data=get_user_data( - backend=BackendType.AWS, - image_name=job.job_spec.image_name, - authorized_keys=[ - run.run_spec.ssh_key_pub.strip(), - project_ssh_public_key.strip(), - ], - registry_auth_required=job.job_spec.registry_auth is not None, + project_id=project_name, ), + user_data=get_user_data(authorized_keys=instance_config.get_public_keys()), tags=tags, security_group_id=aws_resources.create_security_group( ec2_client=ec2_client, - project_id=project_id, + project_id=project_name, vpc_id=vpc_id, ), spot=instance_offer.instance.resources.spot, @@ -187,6 +178,27 @@ def run_job( logger.warning("Got botocore.exceptions.ClientError: %s", e) raise NoCapacityError() + def run_job( + self, + run: Run, + job: Job, + instance_offer: InstanceOfferWithAvailability, + project_ssh_public_key: str, + project_ssh_private_key: str, + ) -> LaunchedInstanceInfo: + instance_config = InstanceConfiguration( + project_name=run.project_name, + 
instance_name=get_instance_name(run, job), # TODO: generate name + ssh_keys=[ + SSHKey(public=run.run_spec.ssh_key_pub.strip()), + SSHKey(public=project_ssh_public_key.strip()), + ], + job_docker_config=None, + user=run.user, + ) + launched_instance_info = self.create_instance(instance_offer, instance_config) + return launched_instance_info + def create_gateway( self, instance_name: str, @@ -228,7 +240,7 @@ def create_gateway( ) -def _has_quota(quotas: Dict[str, float], instance_name: str) -> bool: +def _has_quota(quotas: Dict[str, int], instance_name: str) -> bool: if instance_name.startswith("p"): return quotas.get("P/OnDemand", 0) > 0 if instance_name.startswith("g"): diff --git a/src/dstack/_internal/core/backends/azure/compute.py b/src/dstack/_internal/core/backends/azure/compute.py index fdd6edda3..57a90e3cf 100644 --- a/src/dstack/_internal/core/backends/azure/compute.py +++ b/src/dstack/_internal/core/backends/azure/compute.py @@ -132,12 +132,7 @@ def run_job( # instance_name includes region because Azure may create an instance resource # even when provisioning fails. 
instance_name=f"{get_instance_name(run, job)}-{instance_offer.region}", - user_data=get_user_data( - backend=BackendType.AZURE, - image_name=job.job_spec.image_name, - authorized_keys=ssh_pub_keys, - registry_auth_required=job.job_spec.registry_auth is not None, - ), + user_data=get_user_data(authorized_keys=ssh_pub_keys), ssh_pub_keys=ssh_pub_keys, spot=instance_offer.instance.resources.spot, disk_size=disk_size, diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index fc8c61574..f82addcac 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -2,15 +2,15 @@ import re from abc import ABC, abstractmethod from functools import lru_cache -from typing import List, Optional +from typing import Any, Dict, List, Optional import git import requests import yaml from dstack._internal import settings -from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( + InstanceConfiguration, InstanceOfferWithAvailability, LaunchedGatewayInfo, LaunchedInstanceInfo, @@ -39,10 +39,17 @@ def run_job( ) -> LaunchedInstanceInfo: pass + def create_instance( + self, + instance_offer: InstanceOfferWithAvailability, + instance_config: InstanceConfiguration, + ) -> LaunchedInstanceInfo: + raise NotImplementedError() + @abstractmethod def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None - ): + ) -> None: pass def create_gateway( @@ -60,18 +67,10 @@ def get_instance_name(run: Run, job: Job) -> str: def get_user_data( - backend: BackendType, - image_name: str, authorized_keys: List[str], - registry_auth_required: bool, - cloud_config_kwargs: Optional[dict] = None, + cloud_config_kwargs: Optional[Dict[Any, Any]] = None, ) -> str: - commands = get_shim_commands( - backend=backend, - image_name=image_name, - authorized_keys=authorized_keys, - 
registry_auth_required=registry_auth_required, - ) + commands = get_shim_commands(authorized_keys) return get_cloud_config( runcmd=[["sh", "-c", " && ".join(commands)]], ssh_authorized_keys=authorized_keys, @@ -79,25 +78,18 @@ def get_user_data( ) -def get_shim_commands( - backend: BackendType, - image_name: str, - authorized_keys: List[str], - registry_auth_required: bool, -) -> List[str]: +def get_shim_commands(authorized_keys: List[str]) -> List[str]: build = get_dstack_runner_version() env = { - "DSTACK_BACKEND": backend.value, "DSTACK_RUNNER_LOG_LEVEL": "6", "DSTACK_RUNNER_VERSION": build, - "DSTACK_IMAGE_NAME": image_name, "DSTACK_PUBLIC_SSH_KEY": "\n".join(authorized_keys), "DSTACK_HOME": "/root/.dstack", } commands = get_dstack_shim(build) for k, v in env.items(): commands += [f'export "{k}={v}"'] - commands += get_run_shim_script(registry_auth_required) + commands += get_run_shim_script() return commands @@ -119,18 +111,17 @@ def get_dstack_shim(build: str) -> List[str]: if settings.DSTACK_VERSION is not None: bucket = "dstack-runner-downloads" + url = f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64" + return [ - f'sudo curl --output /usr/local/bin/dstack-shim "https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64"', + f'sudo curl --connect-timeout 60 --max-time 240 --retry 1 --output /usr/local/bin/dstack-shim "{url}"', "sudo chmod +x /usr/local/bin/dstack-shim", ] -def get_run_shim_script(registry_auth_required: bool) -> List[str]: +def get_run_shim_script() -> List[str]: dev_flag = "" if settings.DSTACK_VERSION is not None else "--dev" - with_auth_flag = "--with-auth" if registry_auth_required else "" - return [ - f"nohup dstack-shim {dev_flag} docker {with_auth_flag} --keep-container >/root/shim.log 2>&1 &" - ] + return [f"nohup dstack-shim {dev_flag} docker --keep-container >/root/shim.log 2>&1 &"] def get_gateway_user_data(authorized_key: str) -> str: @@ -183,13 +174,18 @@ def 
get_docker_commands(authorized_keys: List[str]) -> List[str]: # start sshd "/usr/sbin/sshd -p 10022 -o PermitUserEnvironment=yes", ] - build = get_dstack_runner_version() + runner = "/usr/local/bin/dstack-runner" + + build = get_dstack_runner_version() bucket = "dstack-runner-downloads-stgn" if settings.DSTACK_VERSION is not None: bucket = "dstack-runner-downloads" + + url = f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64" + commands += [ - f'curl --output {runner} "https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-runner-linux-amd64"', + f'curl --connect-timeout 60 --max-time 240 --retry 1 --output {runner} "{url}"', f"chmod +x {runner}", f"{runner} --log-level 6 start --http-port 10999 --temp-dir /tmp/runner --home-dir /root --working-dir /workflow", ] diff --git a/src/dstack/_internal/core/backends/base/offers.py b/src/dstack/_internal/core/backends/base/offers.py index f484e6102..89581e619 100644 --- a/src/dstack/_internal/core/backends/base/offers.py +++ b/src/dstack/_internal/core/backends/base/offers.py @@ -27,8 +27,11 @@ def get_catalog_offers( q = requirements_to_query_filter(requirements) q.provider = [provider] offers = [] + catalog = catalog if catalog is not None else gpuhunt.default_catalog() + locs = [] for item in catalog.query(**asdict(q)): + locs.append(item.location) if locations is not None and item.location not in locations: continue offer = catalog_item_to_offer(backend, item, requirements) diff --git a/src/dstack/_internal/core/backends/datacrunch/api_client.py b/src/dstack/_internal/core/backends/datacrunch/api_client.py index 759da3907..47f4f2f6d 100644 --- a/src/dstack/_internal/core/backends/datacrunch/api_client.py +++ b/src/dstack/_internal/core/backends/datacrunch/api_client.py @@ -5,6 +5,7 @@ from datacrunch.exceptions import APIException from datacrunch.instances.instances import Instance +from dstack._internal.core.errors import NoCapacityError from 
dstack._internal.utils.ssh import get_public_key_fingerprint @@ -55,6 +56,7 @@ def wait_for_instance(self, instance_id: str) -> Optional[Instance]: if instance is not None and instance.status == "running": return instance time.sleep(WAIT_FOR_INSTANCE_INTERVAL) + return def deploy_instance( self, @@ -67,16 +69,20 @@ def deploy_instance( disk_size, is_spot=True, location="FIN-01", - ): - instance = self.client.instances.create( - instance_type=instance_type, - image=image, - ssh_key_ids=ssh_key_ids, - hostname=hostname, - description=description, - startup_script_id=startup_script_id, - is_spot=is_spot, - location=location, - os_volume={"name": "OS volume", "size": disk_size}, - ) + ) -> Instance: + try: + instance = self.client.instances.create( + instance_type=instance_type, + image=image, + ssh_key_ids=ssh_key_ids, + hostname=hostname, + description=description, + startup_script_id=startup_script_id, + is_spot=is_spot, + location=location, + os_volume={"name": "OS volume", "size": disk_size}, + ) + except APIException: + raise NoCapacityError() + return instance diff --git a/src/dstack/_internal/core/backends/datacrunch/compute.py b/src/dstack/_internal/core/backends/datacrunch/compute.py index 5f935c296..154b6aa89 100644 --- a/src/dstack/_internal/core/backends/datacrunch/compute.py +++ b/src/dstack/_internal/core/backends/datacrunch/compute.py @@ -1,7 +1,9 @@ from typing import Dict, List, Optional from dstack._internal.core.backends.base import Compute -from dstack._internal.core.backends.base.compute import get_shim_commands +from dstack._internal.core.backends.base.compute import ( + get_shim_commands, +) from dstack._internal.core.backends.base.offers import get_catalog_offers from dstack._internal.core.backends.datacrunch.api_client import DataCrunchAPIClient from dstack._internal.core.backends.datacrunch.config import DataCrunchConfig @@ -9,11 +11,16 @@ from dstack._internal.core.models.backends.base import BackendType from 
dstack._internal.core.models.instances import ( InstanceAvailability, + InstanceConfiguration, InstanceOffer, InstanceOfferWithAvailability, LaunchedInstanceInfo, + SSHKey, ) from dstack._internal.core.models.runs import Job, Requirements, Run +from dstack._internal.utils.logging import get_logger + +logger = get_logger("datacrunch.compute") class DataCrunchCompute(Compute): @@ -57,62 +64,64 @@ def _get_offers_with_availability( return availability_offers - def run_job( + def create_instance( self, - run: Run, - job: Job, instance_offer: InstanceOfferWithAvailability, - project_ssh_public_key: str, - project_ssh_private_key: str, + instance_config: InstanceConfiguration, ) -> LaunchedInstanceInfo: + public_keys = instance_config.get_public_keys() ssh_ids = [] - ssh_ids.append( - self.api_client.get_or_create_ssh_key( - name=f"dstack-{job.job_spec.job_name}.key", - public_key=run.run_spec.ssh_key_pub.strip(), + for ssh_public_key in public_keys: + ssh_ids.append( + # datacrunch allows you to use the same name + self.api_client.get_or_create_ssh_key( + name=f"dstack-{instance_config.instance_name}.key", + public_key=ssh_public_key, + ) ) - ) - ssh_ids.append( - self.api_client.get_or_create_ssh_key( - name=f"dstack-{job.job_spec.job_name}.key", - public_key=project_ssh_public_key.strip(), - ) - ) - commands = get_shim_commands( - backend=BackendType.DATACRUNCH, - image_name=job.job_spec.image_name, - authorized_keys=[ - run.run_spec.ssh_key_pub.strip(), - project_ssh_public_key.strip(), - ], - registry_auth_required=job.job_spec.registry_auth is not None, - ) + commands = get_shim_commands(authorized_keys=public_keys) startup_script = " ".join([" && ".join(commands)]) - script_name = f"dstack-{job.job_spec.job_name}.sh" + script_name = f"dstack-{instance_config.instance_name}.sh" + + logger.debug("startup script:", startup_script) + startup_script_ids = self.api_client.get_or_create_startup_scrpit( name=script_name, script=startup_script ) - name = 
job.job_spec.job_name - # Id of image "Ubuntu 22.04 + CUDA 12.0 + Docker" # from API https://datacrunch.stoplight.io/docs/datacrunch-public/c46ab45dbc508-get-all-image-types image_name = "2088da25-bb0d-41cc-a191-dccae45d96fd" disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024) + instance = self.api_client.deploy_instance( instance_type=instance_offer.instance.name, ssh_key_ids=ssh_ids, startup_script_id=startup_script_ids, - hostname=name, - description=name, + hostname=instance_config.instance_name, + description=instance_config.instance_name, image=image_name, disk_size=disk_size, location=instance_offer.region, ) + logger.debug( + "deploy_instance", + { + "instance_type": instance_offer.instance.name, + "ssh_key_ids": ssh_ids, + "startup_script_id": startup_script_ids, + "hostname": instance_config.instance_name, + "description": instance_config.instance_name, + "image": image_name, + "disk_size": disk_size, + "location": instance_offer.region, + }, + ) + running_instance = self.api_client.wait_for_instance(instance.id) if running_instance is None: raise ComputeError(f"Wait instance {instance.id!r} timeout") @@ -130,7 +139,28 @@ def run_job( return launched_instance + def run_job( + self, + run: Run, + job: Job, + instance_offer: InstanceOfferWithAvailability, + project_ssh_public_key: str, + project_ssh_private_key: str, + ) -> LaunchedInstanceInfo: + instance_config = InstanceConfiguration( + project_name=run.project_name, + instance_name=job.job_spec.job_name, # TODO: generate name + ssh_keys=[ + SSHKey(public=run.run_spec.ssh_key_pub.strip()), + SSHKey(public=project_ssh_public_key.strip()), + ], + job_docker_config=None, + user=run.user, + ) + launched_instance_info = self.create_instance(instance_offer, instance_config) + return launched_instance_info + def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None - ): + ) -> None: self.api_client.delete_instance(instance_id) diff --git 
a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index c3e50bd02..a9955d37a 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -18,12 +18,14 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( InstanceAvailability, + InstanceConfiguration, InstanceOffer, InstanceOfferWithAvailability, InstanceType, LaunchedGatewayInfo, LaunchedInstanceInfo, Resources, + SSHKey, ) from dstack._internal.core.models.runs import Job, Requirements, Run @@ -44,7 +46,7 @@ def get_offers( requirements=requirements, extra_filter=_supported_instances_and_zones(self.config.regions), ) - quotas = defaultdict(dict) + quotas: Dict[str, Dict[str, float]] = defaultdict(dict) for region in self.regions_client.list(project=self.config.project_id): for quota in region.quotas: quotas[region.name][quota.metric] = quota.limit - quota.usage @@ -70,7 +72,7 @@ def get_offers( def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None - ): + ) -> None: try: self.instances_client.delete( project=self.config.project_id, zone=region, instance=instance_id @@ -78,25 +80,21 @@ def terminate_instance( except google.api_core.exceptions.NotFound: pass - def run_job( + def create_instance( self, - run: Run, - job: Job, instance_offer: InstanceOfferWithAvailability, - project_ssh_public_key: str, - project_ssh_private_key: str, + instance_config: InstanceConfiguration, ) -> LaunchedInstanceInfo: - project_id = run.project_name - instance_name = get_instance_name(run, job) + instance_name = instance_config.instance_name + + authorized_keys = instance_config.get_public_keys() + gcp_resources.create_runner_firewall_rules( firewalls_client=self.firewalls_client, project_id=self.config.project_id, ) disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024) - authorized_keys = [ 
- run.run_spec.ssh_key_pub.strip(), - project_ssh_public_key.strip(), - ] + for zone in _get_instance_zones(instance_offer): request = compute_v1.InsertInstanceRequest() request.zone = zone @@ -113,17 +111,12 @@ def run_job( gpus=instance_offer.instance.resources.gpus, ), spot=instance_offer.instance.resources.spot, - user_data=get_user_data( - backend=BackendType.GCP, - image_name=job.job_spec.image_name, - authorized_keys=authorized_keys, - registry_auth_required=job.job_spec.registry_auth is not None, - ), + user_data=get_user_data(authorized_keys), authorized_keys=authorized_keys, labels={ "owner": "dstack", - "dstack_project": project_id, - "dstack_user": run.user, + "dstack_project": instance_config.project_name, + "dstack_user": instance_config.user, }, tags=[gcp_resources.DSTACK_INSTANCE_TAG], instance_name=instance_name, @@ -153,6 +146,27 @@ def run_job( ) raise NoCapacityError() + def run_job( + self, + run: Run, + job: Job, + instance_offer: InstanceOfferWithAvailability, + project_ssh_public_key: str, + project_ssh_private_key: str, + ) -> LaunchedInstanceInfo: + instance_config = InstanceConfiguration( + project_name=run.project_name, + instance_name=get_instance_name(run, job), # TODO: generate name + ssh_keys=[ + SSHKey(public=run.run_spec.ssh_key_pub.strip()), + SSHKey(public=project_ssh_public_key.strip()), + ], + job_docker_config=None, + user=run.user, + ) + launched_instance_info = self.create_instance(instance_offer, instance_config) + return launched_instance_info + def create_gateway( self, instance_name: str, @@ -207,8 +221,6 @@ def create_gateway( def _supported_instances_and_zones( regions: List[str], ) -> Optional[Callable[[InstanceOffer], bool]]: - regions = set(regions) - def _filter(offer: InstanceOffer) -> bool: # strip zone if offer.region[:-2] not in regions: @@ -232,7 +244,7 @@ def _filter(offer: InstanceOffer) -> bool: return _filter -def _has_gpu_quota(quotas: Dict[str, int], resources: Resources) -> bool: +def 
_has_gpu_quota(quotas: Dict[str, float], resources: Resources) -> bool: if not resources.gpus: return True gpu = resources.gpus[0] diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index acf4c9393..a01ec7837 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -200,13 +200,7 @@ def terminate_instance( if e.status != 404: raise - def create_gateway( - self, - instance_name: str, - ssh_key_pub: str, - region: str, - project_id: str, - ) -> LaunchedGatewayInfo: + def create_gateway(self, instance_name: str, ssh_key_pub: str, region: str, project_id: str): # Gateway creation is currently limited to Kubernetes with Load Balancer support. # If the cluster does not support Load Balancer, the service will be provisioned but # the external IP/hostname will never be allocated. diff --git a/src/dstack/_internal/core/backends/lambdalabs/compute.py b/src/dstack/_internal/core/backends/lambdalabs/compute.py index 29c1e29e5..45d39761f 100644 --- a/src/dstack/_internal/core/backends/lambdalabs/compute.py +++ b/src/dstack/_internal/core/backends/lambdalabs/compute.py @@ -57,13 +57,7 @@ def run_job( project_ssh_private_key: str, ) -> LaunchedInstanceInfo: commands = get_shim_commands( - backend=BackendType.LAMBDA, - image_name=job.job_spec.image_name, - authorized_keys=[ - run.run_spec.ssh_key_pub.strip(), - project_ssh_public_key.strip(), - ], - registry_auth_required=job.job_spec.registry_auth is not None, + authorized_keys=[run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip()] ) # shim is asssumed to be run under root launch_command = "sudo sh -c '" + "&& ".join(commands) + "'" diff --git a/src/dstack/_internal/core/backends/local/compute.py b/src/dstack/_internal/core/backends/local/compute.py index 4cee104e1..70b695c97 100644 --- a/src/dstack/_internal/core/backends/local/compute.py +++ 
b/src/dstack/_internal/core/backends/local/compute.py @@ -1,6 +1,6 @@ from typing import List, Optional -from dstack._internal.core.backends.base.compute import Compute, get_dstack_runner_version +from dstack._internal.core.backends.base.compute import Compute from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( InstanceAvailability, @@ -37,6 +37,18 @@ def terminate_instance( ): pass + def create_instance(self, instance_offer, instance_config) -> LaunchedInstanceInfo: + launched_instance = LaunchedInstanceInfo( + instance_id="local", + ip_address="127.0.0.1", + region="", + username="root", + ssh_port=10022, + dockerized=False, + backend_data=None, + ) + return launched_instance + def run_job( self, run: Run, @@ -45,15 +57,6 @@ def run_job( project_ssh_public_key: str, project_ssh_private_key: str, ) -> LaunchedInstanceInfo: - authorized_keys = f"{run.run_spec.ssh_key_pub.strip()}\\n{project_ssh_public_key.strip()}" - logger.info( - "Running job in LocalBackend. 
To start processing, run: `" - f"DSTACK_BACKEND=local " - "DSTACK_RUNNER_LOG_LEVEL=6 " - f"DSTACK_RUNNER_VERSION={get_dstack_runner_version()} " - f"DSTACK_IMAGE_NAME={job.job_spec.image_name} " - f'DSTACK_PUBLIC_SSH_KEY="{authorized_keys}" ./shim --dev docker --keep-container`', - ) return LaunchedInstanceInfo( instance_id="local", ip_address="127.0.0.1", diff --git a/src/dstack/_internal/core/backends/nebius/compute.py b/src/dstack/_internal/core/backends/nebius/compute.py index 7503503c7..fd1f6b3f7 100644 --- a/src/dstack/_internal/core/backends/nebius/compute.py +++ b/src/dstack/_internal/core/backends/nebius/compute.py @@ -77,14 +77,11 @@ def run_job( ), metadata={ "user-data": get_user_data( - backend=BackendType.NEBIUS, - image_name=job.job_spec.image_name, authorized_keys=[ run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip(), - ], - registry_auth_required=job.job_spec.registry_auth is not None, - ), + ] + ) }, disk_size_gb=disk_size, image_id=image_id, diff --git a/src/dstack/_internal/core/backends/tensordock/compute.py b/src/dstack/_internal/core/backends/tensordock/compute.py index 6c95115d4..bc1ff780f 100644 --- a/src/dstack/_internal/core/backends/tensordock/compute.py +++ b/src/dstack/_internal/core/backends/tensordock/compute.py @@ -50,13 +50,7 @@ def run_job( project_ssh_private_key: str, ) -> LaunchedInstanceInfo: commands = get_shim_commands( - backend=BackendType.TENSORDOCK, - image_name=job.job_spec.image_name, - authorized_keys=[ - run.run_spec.ssh_key_pub.strip(), - project_ssh_public_key.strip(), - ], - registry_auth_required=job.job_spec.registry_auth is not None, + authorized_keys=[run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip()] ) try: resp = self.api_client.deploy_single( diff --git a/src/dstack/_internal/core/models/backends/base.py b/src/dstack/_internal/core/models/backends/base.py index 56fcdb585..47286e6ef 100644 --- a/src/dstack/_internal/core/models/backends/base.py +++ 
b/src/dstack/_internal/core/models/backends/base.py @@ -26,6 +26,7 @@ class BackendType(str, enum.Enum): KUBERNETES = "kubernetes" LAMBDA = "lambda" LOCAL = "local" + REMOTE = "remote" # TODO: replace for LOCAL NEBIUS = "nebius" TENSORDOCK = "tensordock" VASTAI = "vastai" diff --git a/src/dstack/_internal/core/models/instances.py b/src/dstack/_internal/core/models/instances.py index fde057b75..6c319d5fb 100644 --- a/src/dstack/_internal/core/models/instances.py +++ b/src/dstack/_internal/core/models/instances.py @@ -1,9 +1,11 @@ from enum import Enum from typing import List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.configurations import RegistryAuth +from dstack._internal.server.services.docker import DockerImage from dstack._internal.utils.common import pretty_resources @@ -60,6 +62,27 @@ class SSHConnectionParams(BaseModel): port: int +class SSHKey(BaseModel): + public: str + private: Optional[str] = None + + +class DockerConfig(BaseModel): + registry_auth: Optional[RegistryAuth] + image: Optional[DockerImage] + + +class InstanceConfiguration(BaseModel): + project_name: str + instance_name: str # unique in pool + ssh_keys: List[SSHKey] + job_docker_config: Optional[DockerConfig] + user: str # dstack user name + + def get_public_keys(self) -> List[str]: + return [ssh_key.public.strip() for ssh_key in self.ssh_keys] + + class LaunchedInstanceInfo(BaseModel): instance_id: str region: str @@ -67,8 +90,8 @@ class LaunchedInstanceInfo(BaseModel): username: str ssh_port: int # could be different from 22 for some backends dockerized: bool # True if backend starts shim - ssh_proxy: Optional[SSHConnectionParams] - backend_data: Optional[str] # backend-specific data in json + ssh_proxy: Optional[SSHConnectionParams] = Field(default=None) + backend_data: Optional[str] = Field(default=None) # backend-specific data in json class 
InstanceAvailability(Enum): @@ -76,6 +99,15 @@ class InstanceAvailability(Enum): AVAILABLE = "available" NOT_AVAILABLE = "not_available" NO_QUOTA = "no_quota" + READY = "ready" + BUSY = "busy" + + def is_available(self) -> bool: + return self in { + InstanceAvailability.UNKNOWN, + InstanceAvailability.AVAILABLE, + InstanceAvailability.READY, + } class InstanceOffer(BaseModel): diff --git a/src/dstack/_internal/core/models/pools.py b/src/dstack/_internal/core/models/pools.py new file mode 100644 index 000000000..beb623a00 --- /dev/null +++ b/src/dstack/_internal/core/models/pools.py @@ -0,0 +1,32 @@ +import datetime +from typing import List, Optional + +from pydantic import BaseModel + +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import InstanceType +from dstack._internal.core.models.runs import InstanceStatus, JobStatus + + +class Pool(BaseModel): + name: str + default: bool + created_at: datetime.datetime + total_instances: int + available_instances: int + + +class Instance(BaseModel): + backend: BackendType + instance_type: InstanceType + name: str + job_name: Optional[str] = None + job_status: Optional[JobStatus] = None + hostname: str + status: InstanceStatus + price: float + + +class PoolInstances(BaseModel): + name: str + instances: List[Instance] diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py index 9677c734f..ce031f814 100644 --- a/src/dstack/_internal/core/models/profiles.py +++ b/src/dstack/_internal/core/models/profiles.py @@ -9,6 +9,8 @@ from dstack._internal.core.models.common import ForbidExtra DEFAULT_RETRY_LIMIT = 3600 +DEFAULT_POOL_NAME = "default-pool" +DEFAULT_TERMINATION_IDLE_TIME = 5 * 60 # 5 minutes by default class SpotPolicy(str, Enum): @@ -17,6 +19,16 @@ class SpotPolicy(str, Enum): AUTO = "auto" +class CreationPolicy(str, Enum): + REUSE = "reuse" + REUSE_OR_CREATE = "reuse-or-create" + + +class TerminationPolicy(str, 
Enum): + DONT_DESTROY = "dont-destroy" + DESTROY_AFTER_IDLE = "destroy-after-idle" + + def parse_duration(v: Optional[Union[int, str]]) -> Optional[int]: if v is None: return None @@ -94,6 +106,21 @@ class Profile(ForbidExtra): default: Annotated[ bool, Field(description="If set to true, `dstack run` will use this profile by default.") ] = False + pool_name: Annotated[ + Optional[str], + Field(description="The name of the pool. If not set, dstack will use the default name."), + ] = None + instance_name: Annotated[Optional[str], Field(description="The name of the instance")] + creation_policy: Annotated[ + Optional[CreationPolicy], Field(description="The policy for using instances from the pool") + ] + termination_policy: Annotated[ + Optional[TerminationPolicy], Field(description="The policy for termination instances") + ] + termination_idle_time: Annotated[ + int, + Field(description="Seconds to wait before destroying the instance"), + ] = DEFAULT_TERMINATION_IDLE_TIME _validate_max_duration = validator("max_duration", pre=True, allow_reuse=True)( parse_max_duration diff --git a/src/dstack/_internal/core/models/resources.py b/src/dstack/_internal/core/models/resources.py index cb7b0613f..ef718f524 100644 --- a/src/dstack/_internal/core/models/resources.py +++ b/src/dstack/_internal/core/models/resources.py @@ -42,11 +42,11 @@ def _post_validate(cls, values): raise ValueError(f"Invalid range order: {min}..{max}") return values - def __str__(self): + def __str__(self) -> str: min = self.min if self.min is not None else "" max = self.max if self.max is not None else "" if min == max: - return f"{min}" + return str(min) return f"{min}..{max}" @@ -191,7 +191,7 @@ class ResourcesSpec(ForbidExtra): cpu (Optional[Range[int]]): The number of CPUs memory (Optional[Range[Memory]]): The size of RAM memory (e.g., `"16GB"`) gpu (Optional[GPUSpec]): The GPU spec - shm_size (Optional[Range[Memory]]): The of shared memory (e.g., `"8GB"`). 
If you are using parallel communicating processes (e.g., dataloaders in PyTorch), you may need to configure this. + shm_size (Optional[Range[Memory]]): The size of shared memory (e.g., `"8GB"`). If you are using parallel communicating processes (e.g., dataloaders in PyTorch), you may need to configure this. disk (Optional[DiskSpec]): The disk spec """ diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 7eb8638a3..d14ebe9c4 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -1,6 +1,6 @@ from datetime import datetime, timedelta from enum import Enum -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Sequence from pydantic import UUID4, BaseModel, Field from typing_extensions import Annotated @@ -23,8 +23,8 @@ class AppSpec(BaseModel): port: int map_to_port: Optional[int] app_name: str - url_path: Optional[str] - url_query_params: Optional[Dict[str, str]] + url_path: Optional[str] = None + url_query_params: Optional[Dict[str, str]] = None class JobStatus(str, Enum): @@ -120,6 +120,7 @@ class JobSpec(BaseModel): requirements: Requirements retry_policy: RetryPolicy working_dir: str + pool_name: Optional[str] class JobProvisioningData(BaseModel): @@ -216,3 +217,27 @@ class RunPlan(BaseModel): user: str run_spec: RunSpec job_plans: List[JobPlan] + + +class InstanceStatus(str, Enum): + PENDING = "pending" + CREATING = "creating" + STARTING = "starting" + READY = "ready" + BUSY = "busy" + TERMINATING = "terminating" + TERMINATED = "terminated" + FAILED = "failed" + + @property + def finished_statuses(cls) -> Sequence["InstanceStatus"]: + return (cls.TERMINATED, cls.FAILED) + + def is_finished(self): + return self in self.finished_statuses + + def is_started(self): + return not self.is_finished() + + def is_available(self) -> bool: + return self in (self.READY, self.BUSY) diff --git a/src/dstack/_internal/core/services/configs/__init__.py 
b/src/dstack/_internal/core/services/configs/__init__.py index f4cd82164..c7005066f 100644 --- a/src/dstack/_internal/core/services/configs/__init__.py +++ b/src/dstack/_internal/core/services/configs/__init__.py @@ -8,7 +8,7 @@ from pydantic import ValidationError from rich import print -from dstack._internal.cli.utils.common import colors, confirm_ask +from dstack._internal.cli.utils.common import confirm_ask from dstack._internal.core.models.config import GlobalConfig, ProjectConfig, RepoConfig from dstack._internal.core.models.repos.base import RepoType from dstack._internal.utils.common import get_dstack_dir @@ -127,9 +127,7 @@ def update_default_project( ( default_project is None or default - or confirm_ask( - f"Update the default project in [{colors['code']}]{config_dir}[/{colors['code']}]?" - ) + or confirm_ask(f"Update the default project in [code]{config_dir}[/]?") ) if not no_default else False @@ -139,4 +137,4 @@ def update_default_project( name=project_name, url=url, token=token, default=set_it_as_default ) config_manager.save() - print(f"Configuration updated at [{colors['code']}]{config_dir}[/{colors['code']}]") + print(f"Configuration updated at [code]{config_dir}[/]") diff --git a/src/dstack/_internal/core/services/ssh/ports.py b/src/dstack/_internal/core/services/ssh/ports.py index b65d6f163..3d81d0f11 100644 --- a/src/dstack/_internal/core/services/ssh/ports.py +++ b/src/dstack/_internal/core/services/ssh/ports.py @@ -66,6 +66,9 @@ def dict(self) -> Dict[int, int]: d[remote_port] = self.sockets[remote_port].getsockname()[1] return d + def __str__(self) -> str: + return f"" + @staticmethod def _listen(port: int) -> Optional[socket.socket]: try: diff --git a/src/dstack/_internal/core/services/ssh/tunnel.py b/src/dstack/_internal/core/services/ssh/tunnel.py index 9e20f71c3..f8ecb6435 100644 --- a/src/dstack/_internal/core/services/ssh/tunnel.py +++ b/src/dstack/_internal/core/services/ssh/tunnel.py @@ -141,13 +141,13 @@ def __init__( id_rsa_path: 
PathLike, control_sock_path: Optional[str] = None, ): - self.temp_dir = tempfile.TemporaryDirectory() if not control_sock_path else None + if control_sock_path is None: + self.temp_dir = tempfile.TemporaryDirectory() + control_sock_path = os.path.join(self.temp_dir.name, "control.sock") super().__init__( host=host, id_rsa_path=id_rsa_path, ports=ports, - control_sock_path=os.path.join(self.temp_dir.name, "control.sock") - if not control_sock_path - else control_sock_path, + control_sock_path=control_sock_path, options={}, ) diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 8ad67c0e9..af6356a32 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -17,6 +17,7 @@ backends, gateways, logs, + pools, projects, repos, runs, @@ -129,6 +130,7 @@ def add_no_api_version_check_routes(paths: List[str]): def register_routes(app: FastAPI): app.include_router(users.router) app.include_router(projects.router) + app.include_router(pools.router) app.include_router(backends.root_router) app.include_router(backends.project_router) app.include_router(repos.router) diff --git a/src/dstack/_internal/server/background/__init__.py b/src/dstack/_internal/server/background/__init__.py index b24876dd4..770dd53fa 100644 --- a/src/dstack/_internal/server/background/__init__.py +++ b/src/dstack/_internal/server/background/__init__.py @@ -3,6 +3,10 @@ from dstack._internal.server.background.tasks.process_finished_jobs import process_finished_jobs from dstack._internal.server.background.tasks.process_pending_jobs import process_pending_jobs +from dstack._internal.server.background.tasks.process_pools import ( + process_pools, + terminate_idle_instance, +) from dstack._internal.server.background.tasks.process_running_jobs import process_running_jobs from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs @@ -18,5 +22,7 @@ def start_background_tasks() -> AsyncIOScheduler: 
_scheduler.add_job(process_running_jobs, IntervalTrigger(seconds=2)) _scheduler.add_job(process_finished_jobs, IntervalTrigger(seconds=2)) _scheduler.add_job(process_pending_jobs, IntervalTrigger(seconds=10)) + _scheduler.add_job(process_pools, IntervalTrigger(seconds=10)) + _scheduler.add_job(terminate_idle_instance, IntervalTrigger(seconds=10)) _scheduler.start() return _scheduler diff --git a/src/dstack/_internal/server/background/tasks/process_finished_jobs.py b/src/dstack/_internal/server/background/tasks/process_finished_jobs.py index 2eacbe8cb..8519dfdb9 100644 --- a/src/dstack/_internal/server/background/tasks/process_finished_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_finished_jobs.py @@ -1,15 +1,13 @@ from sqlalchemy import or_, select from sqlalchemy.orm import joinedload -from dstack._internal.core.models.runs import JobSpec, JobStatus +from dstack._internal.core.models.runs import InstanceStatus, JobSpec, JobStatus from dstack._internal.server.db import get_session_ctx from dstack._internal.server.models import GatewayModel, JobModel from dstack._internal.server.services.gateways import gateway_connections_pool from dstack._internal.server.services.jobs import ( TERMINATING_PROCESSING_JOBS_IDS, TERMINATING_PROCESSING_JOBS_LOCK, - job_model_to_job_submission, - terminate_job_submission_instance, ) from dstack._internal.server.services.logging import job_log from dstack._internal.server.utils.common import run_async @@ -31,7 +29,7 @@ async def process_finished_jobs(): or_(JobModel.remove_at.is_(None), JobModel.remove_at < get_current_datetime()), ) .order_by(JobModel.last_processed_at.asc()) - .limit(1) # TODO(egor-s) process multiple at once + .limit(1) ) job_model = res.scalar() if job_model is None: @@ -46,10 +44,13 @@ async def process_finished_jobs(): async def _process_job(job_id): async with get_session_ctx() as session: res = await session.execute( - select(JobModel).where(JobModel.id == 
job_id).options(joinedload(JobModel.project)) + select(JobModel) + .where(JobModel.id == job_id) + .options(joinedload(JobModel.project)) + .options(joinedload(JobModel.instance)) + .options(joinedload(JobModel.run)) ) job_model = res.scalar_one() - job_submission = job_model_to_job_submission(job_model) job_spec = JobSpec.parse_raw(job_model.job_spec_data) if job_spec.gateway is not None: res = await session.execute( @@ -77,16 +78,14 @@ async def _process_job(job_id): logger.debug(*job_log("service is unregistered", job_model)) except Exception as e: logger.warning("failed to unregister service: %s", e) - try: - if job_submission.job_provisioning_data is not None: - await terminate_job_submission_instance( - project=job_model.project, - job_submission=job_submission, - ) - job_model.removed = True - logger.info(*job_log("marked as removed", job_model)) - except Exception as e: - job_model.removed = False - logger.error(*job_log("failed to terminate job instance: %s", job_model, e)) + + if job_model.instance is not None: + job_model.used_instance_id = job_model.instance.id + job_model.instance.status = InstanceStatus.READY + job_model.instance.last_job_processed_at = get_current_datetime() + job_model.instance = None + + job_model.removed = True job_model.last_processed_at = get_current_datetime() await session.commit() + logger.info(*job_log("marked as removed", job_model)) diff --git a/src/dstack/_internal/server/background/tasks/process_pools.py b/src/dstack/_internal/server/background/tasks/process_pools.py new file mode 100644 index 000000000..affea4aaf --- /dev/null +++ b/src/dstack/_internal/server/background/tasks/process_pools.py @@ -0,0 +1,181 @@ +import datetime +from datetime import timedelta +from typing import Dict +from uuid import UUID + +from pydantic import parse_raw_as +from sqlalchemy import select +from sqlalchemy.orm import joinedload + +from dstack._internal.core.models.profiles import TerminationPolicy +from dstack._internal.core.models.runs 
import InstanceStatus, JobProvisioningData +from dstack._internal.server.db import get_session_ctx +from dstack._internal.server.models import InstanceModel +from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services.jobs import ( + PROCESSING_POOL_IDS, + PROCESSING_POOL_LOCK, + terminate_job_provisioning_data_instance, +) +from dstack._internal.server.services.runner import client +from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel +from dstack._internal.server.utils.common import run_async +from dstack._internal.utils.common import get_current_datetime +from dstack._internal.utils.logging import get_logger + +PENDING_JOB_RETRY_INTERVAL = timedelta(seconds=60) + +logger = get_logger(__name__) + + +async def process_pools() -> None: + async with get_session_ctx() as session: + async with PROCESSING_POOL_LOCK: + res = await session.scalars( + select(InstanceModel).where( + InstanceModel.status.in_( + [ + InstanceStatus.CREATING, + InstanceStatus.STARTING, + InstanceStatus.TERMINATING, + InstanceStatus.READY, + InstanceStatus.BUSY, + ] + ), + InstanceModel.id.not_in(PROCESSING_POOL_IDS), + ) + ) + instances = res.all() + if not instances: + return + + PROCESSING_POOL_IDS.update(i.id for i in instances) + + try: + for inst in instances: + if inst.status in ( + InstanceStatus.CREATING, + InstanceStatus.STARTING, + InstanceStatus.READY, + InstanceStatus.BUSY, + ): + await check_shim(inst.id) + if inst.status == InstanceStatus.TERMINATING: + await terminate(inst.id) + finally: + PROCESSING_POOL_IDS.difference_update(i.id for i in instances) + + +async def check_shim(instance_id: UUID) -> None: + async with get_session_ctx() as session: + instance = ( + await session.scalars( + select(InstanceModel) + .where(InstanceModel.id == instance_id) + .options(joinedload(InstanceModel.project)) + ) + ).one() + ssh_private_key = instance.project.ssh_private_key + job_provisioning_data = 
parse_raw_as(JobProvisioningData, instance.job_provisioning_data) + + instance_health = instance_healthcheck(ssh_private_key, job_provisioning_data) + + logger.info("check instance %s status: %s", instance.name, instance_health) + + if instance_health: + if instance.status in (InstanceStatus.CREATING, InstanceStatus.STARTING): + instance.status = InstanceStatus.READY + await session.commit() + else: + if instance.status in (InstanceStatus.READY, InstanceStatus.BUSY): + instance.status = InstanceStatus.FAILED + await session.commit() + + +@runner_ssh_tunnel(ports=[client.REMOTE_SHIM_PORT], retries=1) +def instance_healthcheck(*, ports: Dict[int, int]) -> bool: + shim_client = client.ShimClient(port=ports[client.REMOTE_SHIM_PORT]) + resp = shim_client.healthcheck() + if resp is None: + return False # shim is not available yet + return resp.service == "dstack-shim" + + +async def terminate(instance_id: UUID) -> None: + async with get_session_ctx() as session: + instance = ( + await session.scalars( + select(InstanceModel) + .where(InstanceModel.id == instance_id) + .options(joinedload(InstanceModel.project)) + ) + ).one() + + # TODO: need lock + + jpd = parse_raw_as(JobProvisioningData, instance.job_provisioning_data) + BACKEND_TYPE = jpd.backend + backends = await backends_services.get_project_backends(project=instance.project) + backend = next((b for b in backends if b.TYPE in BACKEND_TYPE), None) + if backend is None: + raise ValueError(f"there is no backend {BACKEND_TYPE}") + + await run_async( + backend.compute().terminate_instance, jpd.instance_id, jpd.region, jpd.backend_data + ) + + instance.deleted = True + instance.deleted_at = get_current_datetime() + instance.finished_at = get_current_datetime() + instance.status = InstanceStatus.TERMINATED + + logger.info("instance %s terminated", instance.name) + + await session.commit() + + +async def terminate_idle_instance() -> None: + async with get_session_ctx() as session: + res = await session.execute( + 
select(InstanceModel) + .where( + InstanceModel.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE, + InstanceModel.deleted == False, + InstanceModel.job == None, # noqa: E711 + ) + .options(joinedload(InstanceModel.project)) + ) + instances = res.scalars().all() + + # TODO: need lock + + for instance in instances: + last_time = instance.created_at.replace(tzinfo=datetime.timezone.utc) + if instance.last_job_processed_at is not None: + last_time = instance.last_job_processed_at.replace(tzinfo=datetime.timezone.utc) + + idle_seconds = instance.termination_idle_time + delta = datetime.timedelta(seconds=idle_seconds) + + current_time = get_current_datetime().replace(tzinfo=datetime.timezone.utc) + + if last_time + delta < current_time: + jpd: JobProvisioningData = parse_raw_as( + JobProvisioningData, instance.job_provisioning_data + ) + await terminate_job_provisioning_data_instance( + project=instance.project, job_provisioning_data=jpd + ) + instance.deleted = True + instance.deleted_at = get_current_datetime() + instance.finished_at = get_current_datetime() + instance.status = InstanceStatus.TERMINATED + + idle_time = current_time - last_time + logger.info( + "instance %s terminated by termination policy: idle time %ss", + instance.name, + str(idle_time.seconds), + ) + + await session.commit() diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index 2ab974894..34d987d88 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -3,6 +3,7 @@ from uuid import UUID import httpx +from pydantic import parse_raw_as from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload @@ -11,7 +12,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations 
import RegistryAuth from dstack._internal.core.models.repos import RemoteRepoCreds -from dstack._internal.core.models.runs import Job, JobErrorCode, JobStatus, Run +from dstack._internal.core.models.runs import Job, JobErrorCode, JobSpec, JobStatus, Run from dstack._internal.server.db import get_session_ctx from dstack._internal.server.models import ( GatewayModel, @@ -73,7 +74,9 @@ async def process_running_jobs(): async def _process_job(job_id: UUID): async with get_session_ctx() as session: - res = await session.execute(select(JobModel).where(JobModel.id == job_id)) + res = await session.execute( + select(JobModel).where(JobModel.id == job_id).options(joinedload(JobModel.instance)) + ) job_model = res.scalar_one() res = await session.execute( select(RunModel) @@ -89,6 +92,7 @@ async def _process_job(job_id: UUID): job = run.jobs[job_model.job_num] job_submission = job_model_to_job_submission(job_model) job_provisioning_data = job_submission.job_provisioning_data + server_ssh_private_key = project.ssh_private_key secrets = {} # TODO secrets repo_creds = repo_model_to_repo_head(repo_model, include_creds=True).repo_creds @@ -136,7 +140,9 @@ async def _process_job(job_id: UUID): secrets, repo_creds, ) - if not success: # check timeout + + if not success: + # check timeout if job_submission.age > _get_runner_timeout_interval( job_provisioning_data.backend ): @@ -149,6 +155,10 @@ async def _process_job(job_id: UUID): ) job_model.status = JobStatus.FAILED job_model.error_code = JobErrorCode.WAITING_RUNNER_LIMIT_EXCEEDED + job_model.used_instance_id = job_model.instance.id + job_model.instance.last_job_processed_at = common_utils.get_current_datetime() + job_model.instance = None + else: # fails are not acceptable if initial_status == JobStatus.PULLING: logger.debug( @@ -184,6 +194,7 @@ async def _process_job(job_id: UUID): run_model, job_model, ) + if not success: # kill the job logger.warning( *job_log( @@ -194,6 +205,10 @@ async def _process_job(job_id: UUID): ) 
job_model.status = JobStatus.FAILED job_model.error_code = JobErrorCode.INTERRUPTED_BY_NO_CAPACITY + job_model.used_instance_id = job_model.instance.id + job_model.instance.last_job_processed_at = common_utils.get_current_datetime() + job_model.instance = None + if job.is_retry_active(): if job_submission.job_provisioning_data.instance_type.resources.spot: new_job_model = create_job_model_for_new_submission( @@ -202,6 +217,7 @@ async def _process_job(job_id: UUID): status=JobStatus.PENDING, ) session.add(new_job_model) + # job will be terminated by process_finished_jobs if ( @@ -276,6 +292,7 @@ def _process_provisioning_no_shim( Returns: is successful """ + runner_client = client.RunnerClient(port=ports[client.REMOTE_RUNNER_PORT]) resp = runner_client.healthcheck() if resp is None: @@ -308,18 +325,30 @@ def _process_provisioning_with_shim( Returns: is successful """ + job_spec = parse_raw_as(JobSpec, job_model.job_spec_data) + shim_client = client.ShimClient(port=ports[client.REMOTE_SHIM_PORT]) + resp = shim_client.healthcheck() if resp is None: logger.debug(*job_log("shim is not available yet", job_model)) return False # shim is not available yet + if registry_auth is not None: logger.debug(*job_log("authenticating to the registry...", job_model)) interpolate = VariablesInterpolator({"secrets": secrets}).interpolate - shim_client.registry_auth( + shim_client.submit( username=interpolate(registry_auth.username), password=interpolate(registry_auth.password), + image_name=job_spec.image_name, + ) + else: + shim_client.submit( + username="", + password="", + image_name=job_spec.image_name, ) + job_model.status = JobStatus.PULLING logger.info(*job_log("now is pulling", job_model)) return True @@ -396,6 +425,7 @@ def _process_running( last_job_state = resp.job_states[-1] job_model.status = last_job_state.state if job_model.status == JobStatus.DONE: + job_model.run.status = JobStatus.DONE delay_job_instance_termination(job_model) logger.info(*job_log("now is %s", 
job_model, job_model.status.value)) return True diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index caca1a0a0..4e8fdb9d5 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -1,28 +1,40 @@ -from typing import List, Optional +from typing import List, Optional, Tuple from uuid import UUID +from pydantic import parse_raw_as from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload from dstack._internal.core.backends.base import Backend from dstack._internal.core.errors import BackendError -from dstack._internal.core.models.instances import LaunchedInstanceInfo +from dstack._internal.core.models.instances import ( + InstanceOfferWithAvailability, + LaunchedInstanceInfo, +) +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, CreationPolicy from dstack._internal.core.models.runs import ( + InstanceStatus, Job, JobErrorCode, JobProvisioningData, JobStatus, Run, + RunSpec, ) from dstack._internal.server.db import get_session_ctx -from dstack._internal.server.models import JobModel, RunModel +from dstack._internal.server.models import InstanceModel, JobModel, PoolModel, RunModel from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services.jobs import ( SUBMITTED_PROCESSING_JOBS_IDS, SUBMITTED_PROCESSING_JOBS_LOCK, ) from dstack._internal.server.services.logging import job_log +from dstack._internal.server.services.pools import ( + filter_pool_instances, + get_pool_instances, + list_project_pool_models, +) from dstack._internal.server.services.runs import run_model_to_run from dstack._internal.server.utils.common import run_async from dstack._internal.utils import common as common_utils @@ -74,10 +86,79 @@ async def 
_process_submitted_job(session: AsyncSession, job_model: JobModel): ) run_model = res.scalar_one() project_model = run_model.project + + # check default pool + pool = project_model.default_pool + if pool is None: + # TODO: get_or_create_default_pool... + pools = await list_project_pool_models(session, job_model.project) + for pool_item in pools: + if pool_item.id == job_model.project.default_pool_id: + pool = pool_item + if pool_item.name == DEFAULT_POOL_NAME: + pool = pool_item + if pool is None: + pool = PoolModel( + name=DEFAULT_POOL_NAME, + project=project_model, + ) + session.add(pool) + await session.commit() + await session.refresh(pool) + + if pool.id is not None: + project_model.default_pool_id = pool.id + + run_spec = parse_raw_as(RunSpec, run_model.run_spec) + profile = run_spec.profile + run_pool = profile.pool_name + if run_pool is None: + run_pool = pool.name + + # pool capacity + + pool_instances = await get_pool_instances(session, project_model, run_pool) + relevant_instances = filter_pool_instances( + pool_instances, profile, run_spec.configuration.resources, status=InstanceStatus.READY + ) + + if relevant_instances: + sorted_instances = sorted(relevant_instances, key=lambda instance: instance.name) + instance = sorted_instances[0] + + # need lock + instance.status = InstanceStatus.BUSY + instance.job = job_model + + logger.info(*job_log("now is provisioning", job_model)) + job_model.job_provisioning_data = instance.job_provisioning_data + job_model.status = JobStatus.PROVISIONING + job_model.last_processed_at = common_utils.get_current_datetime() + + await session.commit() + + return + run = run_model_to_run(run_model) job = run.jobs[job_model.job_num] + + if profile.creation_policy == CreationPolicy.REUSE: + logger.debug(*job_log("reuse instance failed", job_model)) + if job.is_retry_active(): + logger.debug(*job_log("now is pending because retry is active", job_model)) + job_model.status = JobStatus.PENDING + else: + job_model.status = 
JobStatus.FAILED + job_model.error_code = JobErrorCode.FAILED_TO_START_DUE_TO_NO_CAPACITY + job_model.last_processed_at = common_utils.get_current_datetime() + await session.commit() + return + + # create a new cloud instance backends = await backends_services.get_project_backends(project=run_model.project) - job_provisioning_data = await _run_job( + + # TODO: create VM (backend.compute().create_instance) + job_provisioning_data, offer = await _run_job( job_model=job_model, run=run, job=job, @@ -85,10 +166,30 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): project_ssh_public_key=project_model.ssh_public_key, project_ssh_private_key=project_model.ssh_private_key, ) - if job_provisioning_data is not None: + if job_provisioning_data is not None and offer is not None: logger.info(*job_log("now is provisioning", job_model)) + job_model.job_provisioning_data = job_provisioning_data.json() job_model.status = JobStatus.PROVISIONING + + im = InstanceModel( + name=job.job_spec.job_name, # TODO: make new name + project=project_model, + pool=pool, + created_at=common_utils.get_current_datetime(), + started_at=common_utils.get_current_datetime(), + status=InstanceStatus.BUSY, + job_provisioning_data=job_provisioning_data.json(), + offer=offer.json(), + termination_policy=profile.termination_policy, + termination_idle_time=profile.termination_idle_time, + job=job_model, + backend=offer.backend, + price=offer.price, + region=offer.region, + ) + session.add(im) + else: logger.debug(*job_log("provisioning failed", job_model)) if job.is_retry_active(): @@ -108,16 +209,19 @@ async def _run_job( backends: List[Backend], project_ssh_public_key: str, project_ssh_private_key: str, -) -> Optional[JobProvisioningData]: +) -> Tuple[Optional[JobProvisioningData], Optional[InstanceOfferWithAvailability]]: if run.run_spec.profile.backends is not None: backends = [b for b in backends if b.TYPE in run.run_spec.profile.backends] + try: + requirements = 
job.job_spec.requirements offers = await backends_services.get_instance_offers( - backends, job, exclude_not_available=True + backends, requirements, exclude_not_available=True ) except BackendError as e: logger.warning(*job_log("failed to get instance offers: %s", job_model, repr(e))) - return None + return (None, None) + # Limit number of offers tried to prevent long-running processing # in case all offers fail. for backend, offer in offers[:15]: @@ -153,7 +257,7 @@ async def _run_job( ) continue else: - return JobProvisioningData( + job_provisioning_data = JobProvisioningData( backend=backend.TYPE, instance_type=offer.instance, instance_id=launched_instance_info.instance_id, @@ -166,4 +270,6 @@ async def _run_job( ssh_proxy=launched_instance_info.ssh_proxy, backend_data=launched_instance_info.backend_data, ) - return None + + return (job_provisioning_data, offer) + return (None, None) diff --git a/src/dstack/_internal/server/migrations/versions/27d3e55759fa_add_pools.py b/src/dstack/_internal/server/migrations/versions/27d3e55759fa_add_pools.py new file mode 100644 index 000000000..8869d3491 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/27d3e55759fa_add_pools.py @@ -0,0 +1,151 @@ +"""add pools + +Revision ID: 27d3e55759fa +Revises: d3e8af4786fa +Create Date: 2024-02-12 14:27:52.035476 + +""" +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +# revision identifiers, used by Alembic. +revision = "27d3e55759fa" +down_revision = "d3e8af4786fa" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "pools", + sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("name", sa.String(length=50), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("deleted", sa.Boolean(), nullable=False), + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column( + "project_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + name=op.f("fk_pools_project_id_projects"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_pools")), + ) + op.create_table( + "instances", + sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("name", sa.String(length=50), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("deleted", sa.Boolean(), nullable=False), + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.Column( + "project_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False + ), + sa.Column("pool_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column( + "status", + sa.Enum( + "PENDING", + "CREATING", + "STARTING", + "READY", + "BUSY", + "TERMINATING", + "TERMINATED", + "FAILED", + name="instancestatus", + ), + nullable=False, + ), + sa.Column("status_message", sa.String(length=50), nullable=True), + sa.Column("started_at", sa.DateTime(), nullable=True), + sa.Column("finished_at", sa.DateTime(), nullable=True), + sa.Column("termination_policy", sa.String(length=50), nullable=True), + sa.Column("termination_idle_time", sa.Integer(), nullable=False), + sa.Column( + "backend", + sa.Enum( + "AWS", + "AZURE", + "DATACRUNCH", + "DSTACK", + "GCP", + "KUBERNETES", + "LAMBDA", + "LOCAL", + "REMOTE", + "NEBIUS", + "TENSORDOCK", + "VASTAI", + name="backendtype", + ), + nullable=False, + ), + sa.Column("backend_data", sa.String(length=4000), 
nullable=True), + sa.Column("region", sa.String(length=2000), nullable=False), + sa.Column("price", sa.Float(), nullable=False), + sa.Column("job_provisioning_data", sa.String(length=4000), nullable=False), + sa.Column("offer", sa.String(length=4000), nullable=False), + sa.Column("resource_spec_data", sa.String(length=4000), nullable=True), + sa.Column("job_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True), + sa.Column("last_job_processed_at", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(["job_id"], ["jobs.id"], name=op.f("fk_instances_job_id_jobs")), + sa.ForeignKeyConstraint( + ["pool_id"], ["pools.id"], name=op.f("fk_instances_pool_id_pools") + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + name=op.f("fk_instances_project_id_projects"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_instances")), + ) + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "used_instance_id", + sqlalchemy_utils.types.uuid.UUIDType(binary=False), + nullable=True, + ) + ) + + with op.batch_alter_table("projects", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "default_pool_id", + sqlalchemy_utils.types.uuid.UUIDType(binary=False), + nullable=True, + ) + ) + batch_op.create_foreign_key( + batch_op.f("fk_projects_default_pool_id_pools"), + "pools", + ["default_pool_id"], + ["id"], + ondelete="SET NULL", + use_alter=True, + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("projects", schema=None) as batch_op: + batch_op.drop_constraint( + batch_op.f("fk_projects_default_pool_id_pools"), type_="foreignkey" + ) + batch_op.drop_column("default_pool_id") + + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.drop_column("used_instance_id") + + op.drop_table("instances") + op.drop_table("pools") + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index dc62443f3..8833c8292 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -7,6 +7,7 @@ Boolean, DateTime, Enum, + Float, ForeignKey, Integer, MetaData, @@ -19,8 +20,9 @@ from sqlalchemy_utils import UUIDType from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.profiles import DEFAULT_TERMINATION_IDLE_TIME, TerminationPolicy from dstack._internal.core.models.repos.base import RepoType -from dstack._internal.core.models.runs import JobErrorCode, JobStatus +from dstack._internal.core.models.runs import InstanceStatus, JobErrorCode, JobStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole from dstack._internal.utils.common import get_current_datetime @@ -79,6 +81,13 @@ class ProjectModel(BaseModel): foreign_keys=[default_gateway_id], lazy="selectin" ) + default_pool_id: Mapped[Optional[UUIDType]] = mapped_column( + ForeignKey("pools.id", use_alter=True, ondelete="SET NULL"), nullable=True + ) + default_pool: Mapped["PoolModel"] = relationship( + foreign_keys=[default_pool_id], lazy="selectin" + ) + class MemberModel(BaseModel): __tablename__ = "members" @@ -183,6 +192,8 @@ class JobModel(BaseModel): # `removed` is used to ensure that the instance is killed after the job is finished removed: Mapped[bool] = mapped_column(Boolean, default=False) remove_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + instance: Mapped[Optional["InstanceModel"]] = 
relationship(back_populates="job") + used_instance_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUIDType(binary=False)) class GatewayModel(BaseModel): @@ -230,3 +241,63 @@ class GatewayComputeModel(BaseModel): ssh_public_key: Mapped[str] = mapped_column(Text) deleted: Mapped[bool] = mapped_column(Boolean, server_default=false()) + + +class PoolModel(BaseModel): + __tablename__ = "pools" + + id: Mapped[uuid.UUID] = mapped_column( + UUIDType(binary=False), primary_key=True, default=uuid.uuid4 + ) + name: Mapped[str] = mapped_column(String(50)) + created_at: Mapped[datetime] = mapped_column(DateTime, default=get_current_datetime) + deleted: Mapped[bool] = mapped_column(Boolean, default=False) + deleted_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + + project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE")) + project: Mapped["ProjectModel"] = relationship(foreign_keys=[project_id]) + + instances: Mapped[List["InstanceModel"]] = relationship(back_populates="pool", lazy="selectin") + + +class InstanceModel(BaseModel): + __tablename__ = "instances" + + id: Mapped[uuid.UUID] = mapped_column( + UUIDType(binary=False), primary_key=True, default=uuid.uuid4 + ) + name: Mapped[str] = mapped_column(String(50)) + created_at: Mapped[datetime] = mapped_column(DateTime, default=get_current_datetime) + deleted: Mapped[bool] = mapped_column(Boolean, default=False) + deleted_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + + project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE")) + project: Mapped["ProjectModel"] = relationship(foreign_keys=[project_id], single_parent=True) + + pool_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("pools.id")) + pool: Mapped["PoolModel"] = relationship(back_populates="instances", single_parent=True) + + status: Mapped[InstanceStatus] = mapped_column(Enum(InstanceStatus)) + + # VM + started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, 
default=get_current_datetime) + finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=get_current_datetime) + + termination_policy: Mapped[Optional[TerminationPolicy]] = mapped_column(String(50)) + termination_idle_time: Mapped[int] = mapped_column( + Integer, default=DEFAULT_TERMINATION_IDLE_TIME + ) + + backend: Mapped[BackendType] = mapped_column(Enum(BackendType)) + backend_data: Mapped[Optional[str]] = mapped_column(String(4000)) + region: Mapped[str] = mapped_column(String(2000)) + price: Mapped[float] = mapped_column(Float) + + job_provisioning_data: Mapped[str] = mapped_column(String(4000)) + + offer: Mapped[str] = mapped_column(String(4000)) + + # current job + job_id: Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("jobs.id")) + job: Mapped[Optional["JobModel"]] = relationship(back_populates="instance", lazy="immediate") + last_job_processed_at: Mapped[Optional[datetime]] = mapped_column(DateTime) diff --git a/src/dstack/_internal/server/routers/pools.py b/src/dstack/_internal/server/routers/pools.py new file mode 100644 index 000000000..3b8aa026c --- /dev/null +++ b/src/dstack/_internal/server/routers/pools.py @@ -0,0 +1,125 @@ +from typing import List, Tuple + +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession + +import dstack._internal.core.models.pools as models +import dstack._internal.server.schemas.pools as schemas +import dstack._internal.server.services.pools as pools +from dstack._internal.core.errors import ResourceNotExistsError +from dstack._internal.server.db import get_session +from dstack._internal.server.models import ProjectModel, UserModel +from dstack._internal.server.schemas.runs import AddRemoteInstanceRequest +from dstack._internal.server.security.permissions import ProjectAdmin, ProjectMember +from dstack._internal.server.services.runs import ( + abort_runs_of_pool, + list_project_runs, + run_model_to_run, +) + +router = 
APIRouter(prefix="/api/project/{project_name}/pool", tags=["pool"]) + + +@router.post("/list") +async def list_pool( + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), +) -> List[models.Pool]: + _, project = user_project + return await pools.list_project_pool(session=session, project=project) + + +@router.post("/remove") +async def remove_instance( + body: schemas.RemoveInstanceRequest, + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), +) -> None: + _, project_model = user_project + await pools.remove_instance( + session, project_model, body.pool_name, body.instance_name, body.force + ) + + +@router.post("/set_default") +async def set_default_pool( + body: schemas.SetDefaultPoolRequest, + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), +) -> bool: + _, project_model = user_project + return await pools.set_default_pool(session, project_model, body.pool_name) + + +@router.post("/delete") +async def delete_pool( + body: schemas.DeletePoolRequest, + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), +) -> None: + pool_name = body.name + _, project_model = user_project + + if body.force: + await abort_runs_of_pool(session, project_model, pool_name) + await pools.delete_pool(session, project_model, pool_name) + return + + # check active runs + runs = await list_project_runs(session, project_model, repo_id=None) + active_runs = [] + for run_model in runs: + if run_model.status.is_finished(): + continue + run = run_model_to_run(run_model) + run_pool_name = run.run_spec.profile.pool_name + if run_pool_name == pool_name: + active_runs.append(run) + if active_runs: + return + + # TODO: check active instances + + await pools.delete_pool(session, project_model, pool_name) + + 
+@router.post("/create") +async def create_pool( + body: schemas.CreatePoolRequest, + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), +) -> None: + _, project = user_project + await pools.create_pool_model(session=session, project=project, name=body.name) + + +@router.post("/show") +async def show_pool( + body: schemas.ShowPoolRequest, + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), +) -> models.PoolInstances: + _, project = user_project + instances = await pools.show_pool(session, project, pool_name=body.name) + if instances is None: + raise ResourceNotExistsError("Pool is not found") + return instances + + +@router.post("/add_remote") +async def add_instance( + body: AddRemoteInstanceRequest, + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), +) -> bool: + _, project = user_project + result = await pools.add_remote( + session, + project=project, + resources=body.resources, + profile=body.profile, + instance_name=body.instance_name, + host=body.host, + port=body.port, + ) + return result diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index 318ad4490..29d47920b 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -3,12 +3,17 @@ from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession -from dstack._internal.core.errors import ResourceNotExistsError +from dstack._internal.core.errors import ResourceNotExistsError, ServerClientError +from dstack._internal.core.models.instances import InstanceOfferWithAvailability +from dstack._internal.core.models.pools import Instance from dstack._internal.core.models.runs import Run, RunPlan from dstack._internal.server.db import get_session from dstack._internal.server.models 
import ProjectModel, UserModel from dstack._internal.server.schemas.runs import ( + CreateInstanceRequest, DeleteRunsRequest, + GetOffersRequest, + GetRunPlanRequest, GetRunRequest, ListRunsRequest, StopRunsRequest, @@ -16,6 +21,10 @@ ) from dstack._internal.server.security.permissions import Authenticated, ProjectMember from dstack._internal.server.services import runs +from dstack._internal.server.services.pools import ( + generate_instance_name, + get_or_create_default_pool_by_name, +) root_router = APIRouter( prefix="/api/runs", @@ -58,9 +67,52 @@ async def get_run( return run +@project_router.post("/get_offers") +async def get_offers( + body: GetOffersRequest, + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), +) -> Tuple[str, List[InstanceOfferWithAvailability]]: + _, project = user_project + + active_pool = await get_or_create_default_pool_by_name( + session, project, body.profile.pool_name + ) + + offers = await runs.get_run_plan_by_requirements(project, body.profile, body.requirements) + instances = [instance for _, instance in offers] + + return active_pool.name, instances + + +@project_router.post("/create_instance") +async def create_instance( + body: CreateInstanceRequest, + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), +) -> Instance: + user, project = user_project + instance_name = await generate_instance_name( + session=session, project=project, pool_name=body.pool_name + ) + instance = await runs.create_instance( + session=session, + project=project, + user=user, + ssh_key=body.ssh_key, + pool_name=body.pool_name, + instance_name=instance_name, + profile=body.profile, + requirements=body.requirements, + ) + if instance is None: + raise ServerClientError(msg="Failed to create an instance") + return instance + + @project_router.post("/get_plan") async def get_run_plan( - body: SubmitRunRequest, + body: 
GetRunPlanRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), ) -> RunPlan: diff --git a/src/dstack/_internal/server/schemas/pools.py b/src/dstack/_internal/server/schemas/pools.py new file mode 100644 index 000000000..f4eccc6f1 --- /dev/null +++ b/src/dstack/_internal/server/schemas/pools.py @@ -0,0 +1,26 @@ +from typing import Optional + +from pydantic import BaseModel + + +class DeletePoolRequest(BaseModel): + name: str + force: bool + + +class CreatePoolRequest(BaseModel): + name: str + + +class ShowPoolRequest(BaseModel): + name: Optional[str] + + +class RemoveInstanceRequest(BaseModel): + pool_name: str + instance_name: str + force: bool = False + + +class SetDefaultPoolRequest(BaseModel): + pool_name: str diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index 58b87c053..97672ca40 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -63,8 +63,10 @@ class SubmitBody(BaseModel): class HealthcheckResponse(BaseModel): service: str + version: str -class RegistryAuthBody(BaseModel): +class DockerImageBody(BaseModel): username: str password: str + image_name: str diff --git a/src/dstack/_internal/server/schemas/runs.py b/src/dstack/_internal/server/schemas/runs.py index 38064a1db..a81371461 100644 --- a/src/dstack/_internal/server/schemas/runs.py +++ b/src/dstack/_internal/server/schemas/runs.py @@ -2,7 +2,10 @@ from pydantic import BaseModel -from dstack._internal.core.models.runs import RunSpec +from dstack._internal.core.models.instances import SSHKey +from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.models.resources import ResourcesSpec +from dstack._internal.core.models.runs import Requirements, RunSpec class ListRunsRequest(BaseModel): @@ -18,6 +21,26 @@ class GetRunPlanRequest(BaseModel): run_spec: RunSpec +class 
GetOffersRequest(BaseModel): + profile: Profile + requirements: Requirements + + +class CreateInstanceRequest(BaseModel): + pool_name: str + profile: Profile + requirements: Requirements + ssh_key: SSHKey + + +class AddRemoteInstanceRequest(BaseModel): + instance_name: Optional[str] + host: str + port: str + resources: ResourcesSpec + profile: Profile + + class SubmitRunRequest(BaseModel): run_spec: RunSpec diff --git a/src/dstack/_internal/server/services/backends/__init__.py b/src/dstack/_internal/server/services/backends/__init__.py index 09a54cc0a..3a05e262c 100644 --- a/src/dstack/_internal/server/services/backends/__init__.py +++ b/src/dstack/_internal/server/services/backends/__init__.py @@ -20,10 +20,9 @@ ) from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( - InstanceAvailability, InstanceOfferWithAvailability, ) -from dstack._internal.core.models.runs import Job +from dstack._internal.core.models.runs import Requirements from dstack._internal.server.models import BackendModel, ProjectModel from dstack._internal.server.services.backends.configurators.base import Configurator from dstack._internal.server.settings import LOCAL_BACKEND_ENABLED @@ -292,27 +291,22 @@ async def get_project_backend_model_by_type( return None -_NOT_AVAILABLE = {InstanceAvailability.NOT_AVAILABLE, InstanceAvailability.NO_QUOTA} - - async def get_instance_offers( - backends: List[Backend], job: Job, exclude_not_available: bool = False + backends: List[Backend], requirements: Requirements, exclude_not_available: bool = False ) -> List[Tuple[Backend, InstanceOfferWithAvailability]]: """ Returns list of instances satisfying minimal resource requirements sorted by price """ - tasks = [ - run_async(backend.compute().get_offers, job.job_spec.requirements) for backend in backends - ] + tasks = [run_async(backend.compute().get_offers, requirements) for backend in backends] offers_by_backend = [ [ (backend, offer) for offer 
in backend_offers - if not exclude_not_available or offer.availability not in _NOT_AVAILABLE + if not exclude_not_available or offer.availability.is_available() ] for backend, backend_offers in zip(backends, await asyncio.gather(*tasks)) ] # Merge preserving order for every backend offers = heapq.merge(*offers_by_backend, key=lambda i: i[1].price) - # Put NOT_AVAILABLE and NO_QUOTA instances at the end, do not sort by price - return sorted(offers, key=lambda i: i[1].availability in _NOT_AVAILABLE) + # Put NOT_AVAILABLE, NO_QUOTA, and BUSY instances at the end, do not sort by price + return sorted(offers, key=lambda i: not i[1].availability.is_available()) diff --git a/src/dstack/_internal/server/services/docker.py b/src/dstack/_internal/server/services/docker.py index db2687391..e16349df2 100644 --- a/src/dstack/_internal/server/services/docker.py +++ b/src/dstack/_internal/server/services/docker.py @@ -1,8 +1,8 @@ -from dataclasses import dataclass from enum import Enum from typing import Optional import requests +from pydantic import BaseModel manifests_media_types = [ "application/vnd.oci.image.index.v1+json", @@ -12,8 +12,11 @@ ] -@dataclass(frozen=True) -class DockerImage: +class DockerImage(BaseModel): + class Config: + frozen = True + + image: str registry: Optional[str] repo: str tag: str @@ -115,7 +118,7 @@ def parse_image_name(image: str) -> DockerImage: registry = components[0] repo = "/".join(components[1:]) - return DockerImage(registry, repo, tag, digest) + return DockerImage(image=image, registry=registry, repo=repo, tag=tag, digest=digest) def is_host(s: str) -> bool: diff --git a/src/dstack/_internal/server/services/gateways/pool.py b/src/dstack/_internal/server/services/gateways/pool.py index 11215c2a6..fd3a96d55 100644 --- a/src/dstack/_internal/server/services/gateways/pool.py +++ b/src/dstack/_internal/server/services/gateways/pool.py @@ -8,7 +8,7 @@ class GatewayConnectionsPool: - def __init__(self): + def __init__(self) -> None: 
self._connections: Dict[str, GatewayConnection] = {} self._lock = asyncio.Lock() self.server_port: Optional[int] = None @@ -39,7 +39,7 @@ async def remove(self, hostname: str) -> bool: await stop_task return True - async def remove_all(self): + async def remove_all(self) -> None: async with self._lock: await asyncio.gather( *(conn.tunnel.stop() for conn in self._connections.values()), @@ -54,4 +54,4 @@ async def all(self) -> List[GatewayConnection]: return list(self._connections.values()) -gateway_connections_pool = GatewayConnectionsPool() +gateway_connections_pool: GatewayConnectionsPool = GatewayConnectionsPool() diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 7f8f5148d..c60d8be3b 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -39,6 +39,10 @@ RUNNING_PROCESSING_JOBS_LOCK = asyncio.Lock() RUNNING_PROCESSING_JOBS_IDS = set() +PROCESSING_POOL_LOCK = asyncio.Lock() +PROCESSING_POOL_IDS = set() + + TERMINATING_PROCESSING_JOBS_LOCK = asyncio.Lock() TERMINATING_PROCESSING_JOBS_IDS = set() @@ -102,11 +106,7 @@ async def stop_job( job_submission = job_model_to_job_submission(job_model) if new_status == JobStatus.TERMINATED and job_model.status == JobStatus.RUNNING: try: - await run_async( - _stop_runner, - job_submission, - project.ssh_private_key, - ) + await run_async(_stop_runner, job_submission, project.ssh_private_key) # delay termination for 15 seconds to allow the runner to stop gracefully delay_job_instance_termination(job_model) except SSHError: @@ -119,20 +119,19 @@ async def stop_job( logger.info(*job_log("%s by user", job_model, new_status.value)) -async def terminate_job_submission_instance( - project: ProjectModel, - job_submission: JobSubmission, +async def terminate_job_provisioning_data_instance( + project: ProjectModel, job_provisioning_data: JobProvisioningData ): backend = await 
get_project_backend_by_type( project=project, - backend_type=job_submission.job_provisioning_data.backend, + backend_type=job_provisioning_data.backend, ) - logger.debug("Terminating runner instance %s", job_submission.job_provisioning_data.hostname) + logger.debug("Terminating runner instance %s", job_provisioning_data.hostname) await run_async( backend.compute().terminate_instance, - job_submission.job_provisioning_data.instance_id, - job_submission.job_provisioning_data.region, - job_submission.job_provisioning_data.backend_data, + job_provisioning_data.instance_id, + job_provisioning_data.region, + job_provisioning_data.backend_data, ) diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index d2a6b25fc..7a2387fb7 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -22,6 +22,16 @@ from dstack._internal.core.services.ssh.ports import filter_reserved_ports +def get_default_python_verison() -> str: + version_info = sys.version_info + return PythonVersion(f"{version_info.major}.{version_info.minor}").value + + +def get_default_image(python_version: str) -> str: + # TODO: non-cuda image + return f"dstackai/base:py{python_version}-{version.base_image}-cuda-12.1" + + class JobConfigurator(ABC): TYPE: ConfigurationType @@ -43,6 +53,7 @@ def get_job_specs(self) -> List[JobSpec]: requirements=self._requirements(), retry_policy=self._retry_policy(), working_dir=self._working_dir(), + pool_name=self._pool_name(), ) return [job_spec] @@ -113,8 +124,7 @@ def _home_dir(self) -> Optional[str]: def _image_name(self) -> str: if self.run_spec.configuration.image is not None: return self.run_spec.configuration.image - # TODO: non-cuda image - return f"dstackai/base:py{self._python()}-{version.base_image}-cuda-12.1" + return get_default_image(self._python()) def _max_duration(self) -> 
Optional[int]: if self.run_spec.profile.max_duration is None: @@ -140,8 +150,10 @@ def _working_dir(self) -> str: def _python(self) -> str: if self.run_spec.configuration.python is not None: return self.run_spec.configuration.python.value - version_info = sys.version_info - return PythonVersion(f"{version_info.major}.{version_info.minor}").value + return get_default_python_verison() + + def _pool_name(self): + return self.run_spec.profile.pool_name def _join_shell_commands(commands: List[str], env: Optional[Dict[str, str]] = None) -> str: diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py new file mode 100644 index 000000000..4260f0636 --- /dev/null +++ b/src/dstack/_internal/server/services/pools.py @@ -0,0 +1,417 @@ +import asyncio +from datetime import timezone +from typing import Dict, List, Optional, Sequence + +import gpuhunt +from pydantic import parse_raw_as +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload + +from dstack._internal.core.backends.base.offers import ( + offer_to_catalog_item, + requirements_to_query_filter, +) +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import ( + Gpu, + InstanceAvailability, + InstanceOffer, + InstanceOfferWithAvailability, + InstanceType, + Resources, +) +from dstack._internal.core.models.pools import Instance, Pool, PoolInstances +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile, SpotPolicy +from dstack._internal.core.models.resources import ResourcesSpec +from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData, Requirements +from dstack._internal.server import settings +from dstack._internal.server.models import InstanceModel, PoolModel, ProjectModel +from dstack._internal.utils import common as common_utils +from dstack._internal.utils import random_names +from 
dstack._internal.utils.common import get_current_datetime +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +async def list_project_pool(session: AsyncSession, project: ProjectModel) -> List[Pool]: + pools = list(await list_project_pool_models(session=session, project=project)) + if not pools: + pool = await create_pool_model(session, project, DEFAULT_POOL_NAME) + pools.append(pool) + return [pool_model_to_pool(p) for p in pools] + + +async def get_pool( + session: AsyncSession, project: ProjectModel, pool_name: str +) -> Optional[PoolModel]: + pool = ( + await session.scalars( + select(PoolModel).where( + PoolModel.name == pool_name, + PoolModel.project_id == project.id, + PoolModel.deleted == False, + ) + ) + ).one_or_none() + return pool + + +async def get_or_create_default_pool_by_name( + session: AsyncSession, project: ProjectModel, pool_name: Optional[str] +) -> PoolModel: + active_pool = None + if pool_name is None: + default_pool = None + pools = [ + pool + for pool in (await list_project_pool_models(session, project)) + if project.default_pool == pool + ] + if pools: + default_pool = pools[0] + if not default_pool: + default_pool = await create_pool_model(session, project, DEFAULT_POOL_NAME) + active_pool = default_pool + else: + active_pool = await get_pool(session, project, pool_name) + if active_pool is None: + active_pool = await create_pool_model(session, project, DEFAULT_POOL_NAME) + return active_pool + + +def pool_model_to_pool(pool_model: PoolModel) -> Pool: + total = 0 + available = 0 + for instance in pool_model.instances: + if not instance.deleted: + total += 1 + if instance.status.is_available(): + available += 1 + return Pool( + name=pool_model.name, + default=pool_model.project.default_pool_id == pool_model.id, + created_at=pool_model.created_at.replace(tzinfo=timezone.utc), + total_instances=total, + available_instances=available, + ) + + +async def create_pool_model(session: AsyncSession, project: 
ProjectModel, name: str) -> PoolModel: + pools = await session.scalars( + select(PoolModel) + .where(PoolModel.name == name, PoolModel.project == project, PoolModel.deleted == False) + .options(joinedload(PoolModel.instances)) + ) + if pools.unique().all(): + raise ValueError("duplicate pool name") # TODO: return error with description + + pool = PoolModel( + name=name, + project_id=project.id, + ) + + if project.default_pool is None: + project.default_pool = pool + + session.add(pool) + await session.commit() + await session.refresh(pool) + + return pool + + +async def list_project_pool_models( + session: AsyncSession, project: ProjectModel +) -> Sequence[PoolModel]: + pools = await session.scalars( + select(PoolModel) + .where(PoolModel.project_id == project.id, PoolModel.deleted == False) + .options(joinedload(PoolModel.instances)) + ) + return pools.unique().all() + + +async def set_default_pool(session: AsyncSession, project: ProjectModel, pool_name: str) -> bool: + pool = ( + await session.scalars( + select(PoolModel).where( + PoolModel.name == pool_name, + PoolModel.project == project, + PoolModel.deleted == False, + ) + ) + ).one_or_none() + + if pool is None: + return False + project.default_pool = pool + + await session.commit() + return True + + +async def remove_instance( + session: AsyncSession, + project: ProjectModel, + pool_name: str, + instance_name: str, + force: bool, +) -> None: + pool = await get_pool(session, project, pool_name) + + if pool is None: + logger.warning("Couldn't find pool") + return + + # TODO: need lock + terminated = False + for instance in pool.instances: + if instance.name == instance_name: + if force or instance.job_id is None: + instance.status = InstanceStatus.TERMINATING + terminated = True + + if not terminated: + logger.warning("Couldn't find instance to terminate") + + await session.commit() + + +async def delete_pool(session: AsyncSession, project: ProjectModel, pool_name: str) -> None: + """delete the pool and set 
the default pool to project""" + + default_pool: Optional[PoolModel] = None + default_pool_removed = False + + for pool in await list_project_pool_models(session, project): + if pool.name == DEFAULT_POOL_NAME: + default_pool = pool + + if pool_name == pool.name: + if project.default_pool_id == pool.id: + default_pool_removed = True + pool.deleted = True + pool.deleted_at = get_current_datetime() + + if default_pool_removed: + if default_pool is not None: + project.default_pool = default_pool + else: + await create_pool_model(session, project, DEFAULT_POOL_NAME) + + await session.commit() + + +async def list_deleted_pools( + session: AsyncSession, project_model: ProjectModel +) -> Sequence[PoolModel]: + pools = await session.scalars( + select(PoolModel).where( + PoolModel.project_id == project_model.id, PoolModel.deleted == True + ) + ) + return pools.all() + + +def instance_model_to_instance(instance_model: InstanceModel) -> Instance: + offer: InstanceOfferWithAvailability = parse_raw_as( + InstanceOfferWithAvailability, instance_model.offer + ) + jpd: JobProvisioningData = parse_raw_as( + JobProvisioningData, instance_model.job_provisioning_data + ) + + instance = Instance( + backend=offer.backend, + name=instance_model.name, + instance_type=jpd.instance_type, + hostname=jpd.hostname, + status=instance_model.status, + price=offer.price, + ) + if instance_model.job is not None: + instance.job_name = instance_model.job.job_name + instance.job_status = instance_model.job.status + + return instance + + +async def show_pool( + session: AsyncSession, project: ProjectModel, pool_name: Optional[str] +) -> Optional[PoolInstances]: + """Show active instances in the pool (specified or default). 
Return None if the pool is not found.""" + if pool_name is None: + pool = project.default_pool + else: + pool = await get_pool(session, project, pool_name) + + if pool is None: + return None + return PoolInstances( + name=pool.name, + instances=[instance_model_to_instance(i) for i in pool.instances if not i.deleted], + ) + + +async def get_pool_instances( + session: AsyncSession, project: ProjectModel, pool_name: str +) -> List[InstanceModel]: + res = await session.execute( + select(PoolModel).where( + PoolModel.name == pool_name, + PoolModel.project_id == project.id, + PoolModel.deleted == False, + ) + ) + result = res.unique().scalars().one_or_none() + if result is None: + return [] + instances: List[InstanceModel] = result.instances + return instances + + +async def get_instances_by_pool_id(session: AsyncSession, pool_id: str) -> List[InstanceModel]: + res = await session.execute( + select(PoolModel) + .where( + PoolModel.id == pool_id, + ) + .options(joinedload(PoolModel.instances)) + ) + result = res.unique().scalars().one_or_none() + if result is None: + return [] + instances: List[InstanceModel] = result.instances + return instances + + +_GENERATE_POOL_NAME_LOCK: Dict[str, asyncio.Lock] = {} + + +async def generate_instance_name( + session: AsyncSession, + project: ProjectModel, + pool_name: str, +) -> str: + lock = _GENERATE_POOL_NAME_LOCK.setdefault(project.name, asyncio.Lock()) + async with lock: + pool_instances: List[InstanceModel] = await get_pool_instances(session, project, pool_name) + names = {g.name for g in pool_instances} + while True: + name = f"{random_names.generate_name()}" + if name not in names: + return name + + +async def add_remote( + session: AsyncSession, + resources: ResourcesSpec, + project: ProjectModel, + profile: Profile, + instance_name: Optional[str], + host: str, + port: str, +) -> bool: + pool_model = await get_or_create_default_pool_by_name(session, project, profile.pool_name) + + profile.pool_name = pool_model.name + if 
instance_name is None: + instance_name = await generate_instance_name(session, project, profile.pool_name) + + gpus = [] + if resources.gpu is not None: + gpus = [ + Gpu(name=resources.gpu.name, memory_mib=resources.gpu.memory) + ] * resources.gpu.count.min + + instance_resource = Resources( + cpus=resources.cpu.min, memory_mib=resources.memory.min, gpus=gpus, spot=False + ) + + local = JobProvisioningData( + backend=BackendType.REMOTE, + instance_type=InstanceType(name="local", resources=instance_resource), + instance_id=instance_name, + hostname=host, + region="", + price=0, + username="", + ssh_port=22, + dockerized=False, + backend_data="", + ssh_proxy=None, + ) + offer = InstanceOfferWithAvailability( + backend=BackendType.REMOTE, + instance=InstanceType( + name="instance", + resources=instance_resource, + ), + region="", + price=0.0, + availability=InstanceAvailability.AVAILABLE, + ) + + im = InstanceModel( + name=instance_name, + project=project, + pool=pool_model, + created_at=common_utils.get_current_datetime(), + started_at=common_utils.get_current_datetime(), + status=InstanceStatus.PENDING, + job_provisioning_data=local.json(), + offer=offer.json(), + termination_policy=profile.termination_policy, + termination_idle_time=profile.termination_idle_time, + ) + session.add(im) + await session.commit() + + return True + + +def filter_pool_instances( + pool_instances: List[InstanceModel], + profile: Profile, + resources: ResourcesSpec, + *, + status: Optional[InstanceStatus] = None, +) -> List[InstanceModel]: + """ + Filter instances by `instance_name`, `backends`, `resources`, `spot_policy`, `max_price`, `status` + """ + instances: List[InstanceModel] = [] + candidates: List[InstanceModel] = [] + for instance in pool_instances: + if profile.instance_name is not None and instance.name != profile.instance_name: + continue + if status is not None and instance.status != status: + continue + + # TODO: remove on prod + if settings.LOCAL_BACKEND_ENABLED and 
instance.backend == BackendType.LOCAL: + instances.append(instance) + continue + + if profile.backends is not None and instance.backend not in profile.backends: + continue + candidates.append(instance) + + requirements = Requirements( + resources=resources, + max_price=profile.max_price, + spot={ + None: None, + SpotPolicy.AUTO: None, + SpotPolicy.SPOT: True, + SpotPolicy.ONDEMAND: False, + }[profile.spot_policy], + ) + query_filter = requirements_to_query_filter(requirements) + for instance in candidates: + catalog_item = offer_to_catalog_item(parse_raw_as(InstanceOffer, instance.offer)) + if gpuhunt.matches(catalog_item, query_filter): + instances.append(instance) + return instances diff --git a/src/dstack/_internal/server/services/projects.py b/src/dstack/_internal/server/services/projects.py index 386a84449..05f4e8330 100644 --- a/src/dstack/_internal/server/services/projects.py +++ b/src/dstack/_internal/server/services/projects.py @@ -78,11 +78,13 @@ async def create_project(session: AsyncSession, user: UserModel, project_name: s user=user, project_role=ProjectRole.ADMIN, ) - project = await get_project_model_by_name_or_error(session=session, project_name=project_name) + project_model = await get_project_model_by_name_or_error( + session=session, project_name=project_name + ) for hook in _CREATE_PROJECT_HOOKS: - await hook(session, project) - await session.refresh(project) # a hook may change project - return project_model_to_project(project) + await hook(session, project_model) + await session.refresh(project_model) # a hook may change project + return project_model_to_project(project_model) async def delete_projects( diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index 7520c67e9..bd4045ae5 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -9,9 +9,9 @@ from dstack._internal.core.models.repos.remote import 
RemoteRepoCreds from dstack._internal.core.models.runs import JobSpec, RunSpec from dstack._internal.server.schemas.runner import ( + DockerImageBody, HealthcheckResponse, PullResponse, - RegistryAuthBody, SubmitBody, ) @@ -99,10 +99,13 @@ def healthcheck(self) -> Optional[HealthcheckResponse]: except requests.exceptions.RequestException: return None - def registry_auth(self, username: str, password: str): + def submit(self, username: str, password: str, image_name: str): + post_body = DockerImageBody( + username=username, password=password, image_name=image_name + ).dict() resp = requests.post( - self._url("/api/registry_auth"), - json=RegistryAuthBody(username=username, password=password).dict(), + self._url("/api/submit"), + json=post_body, ) resp.raise_for_status() diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index 959bff6f1..42bbd53d9 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -10,6 +10,7 @@ from dstack._internal.core.models.runs import JobProvisioningData from dstack._internal.core.services.ssh.tunnel import RunnerTunnel from dstack._internal.server.services.jobs import get_runner_ports +from dstack._internal.server.settings import LOCAL_BACKEND_ENABLED from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -33,6 +34,12 @@ def wrapper( Returns: is successful """ + + if LOCAL_BACKEND_ENABLED: + # without SSH + port_map = {p: p for p in ports} + return func(*args, ports=port_map, **kwargs) + func_kwargs_names = [ p.name for p in inspect.signature(func).parameters.values() diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 301218747..a44aa52b6 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -1,8 +1,9 @@ import asyncio import itertools +import math import uuid 
from datetime import timezone -from typing import List, Optional +from typing import List, Optional, Tuple, cast import pydantic from sqlalchemy import select, update @@ -11,13 +12,27 @@ import dstack._internal.server.services.gateways as gateways import dstack._internal.utils.common as common_utils -from dstack._internal.core.errors import RepoDoesNotExistError, ServerClientError +from dstack._internal.core.backends.base import Backend +from dstack._internal.core.errors import BackendError, RepoDoesNotExistError, ServerClientError +from dstack._internal.core.models.instances import ( + DockerConfig, + InstanceAvailability, + InstanceConfiguration, + InstanceOfferWithAvailability, + LaunchedInstanceInfo, + SSHKey, +) +from dstack._internal.core.models.pools import Instance +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, CreationPolicy, Profile from dstack._internal.core.models.runs import ( + InstanceStatus, Job, JobPlan, + JobProvisioningData, JobSpec, JobStatus, JobSubmission, + Requirements, Run, RunPlan, RunSpec, @@ -25,15 +40,35 @@ ServiceModelInfo, ) from dstack._internal.core.models.users import GlobalRole -from dstack._internal.server.models import JobModel, ProjectModel, RunModel, UserModel +from dstack._internal.server.models import ( + InstanceModel, + JobModel, + ProjectModel, + RunModel, + UserModel, +) from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services import pools as pools_services from dstack._internal.server.services import repos as repos_services +from dstack._internal.server.services.docker import parse_image_name from dstack._internal.server.services.jobs import ( get_jobs_from_run_spec, job_model_to_job_submission, stop_job, ) +from dstack._internal.server.services.jobs.configurators.base import ( + get_default_image, + get_default_python_verison, +) +from dstack._internal.server.services.pools import ( + create_pool_model, + filter_pool_instances, + 
get_or_create_default_pool_by_name, + get_pool_instances, + instance_model_to_instance, +) from dstack._internal.server.services.projects import list_project_models, list_user_project_models +from dstack._internal.server.utils.common import run_async from dstack._internal.utils.logging import get_logger from dstack._internal.utils.random_names import generate_name @@ -118,35 +153,208 @@ async def get_run( return run_model_to_run(run_model) -async def get_run_plan( +async def get_run_plan_by_requirements( + project: ProjectModel, + profile: Profile, + requirements: Requirements, + exclude_not_available=False, +) -> List[Tuple[Backend, InstanceOfferWithAvailability]]: + backends: List[Backend] = await backends_services.get_project_backends(project=project) + + if profile.backends is not None: + backends = [b for b in backends if b.TYPE in profile.backends] + + offers = await backends_services.get_instance_offers( + backends=backends, + requirements=requirements, + exclude_not_available=exclude_not_available, + ) + + return offers + + +async def create_instance( session: AsyncSession, project: ProjectModel, user: UserModel, - run_spec: RunSpec, + ssh_key: SSHKey, + pool_name: str, + instance_name: str, + profile: Profile, + requirements: Requirements, +) -> Optional[Instance]: + offers = await get_run_plan_by_requirements( + project, profile, requirements, exclude_not_available=True + ) + + if not offers: + return + + user_ssh_key = ssh_key + project_ssh_key = SSHKey( + public=project.ssh_public_key.strip(), + private=project.ssh_private_key.strip(), + ) + + image = parse_image_name(get_default_image(get_default_python_verison())) + instance_config = InstanceConfiguration( + project_name=project.name, + instance_name=instance_name, + ssh_keys=[user_ssh_key, project_ssh_key], + job_docker_config=DockerConfig( + image=image, + registry_auth=None, + ), + user=user.name, + ) + + pool = await pools_services.get_pool(session, project, pool_name) + + if pool is None: + pool 
= await create_pool_model(session, project, pool_name) + + for backend, instance_offer in offers: + logger.debug( + "trying %s in %s/%s for $%0.4f per hour", + instance_offer.instance.name, + instance_offer.backend.value, + instance_offer.region, + instance_offer.price, + ) + try: + launched_instance_info: LaunchedInstanceInfo = await run_async( + backend.compute().create_instance, + instance_offer, + instance_config, + ) + except BackendError as e: + logger.warning( + "%s launch in %s/%s failed: %s", + instance_offer.instance.name, + instance_offer.backend.value, + instance_offer.region, + repr(e), + ) + continue + + job_provisioning_data = JobProvisioningData( + backend=backend.TYPE, + instance_type=instance_offer.instance, + instance_id=launched_instance_info.instance_id, + hostname=launched_instance_info.ip_address, + region=launched_instance_info.region, + price=instance_offer.price, + username=launched_instance_info.username, + ssh_port=launched_instance_info.ssh_port, + dockerized=launched_instance_info.dockerized, + backend_data=launched_instance_info.backend_data, + ssh_proxy=None, + ) + + # types of queries + # 1. Get all available instance + # 2. Get job's instance (process job) + # 3. 
Get instance's jobs history + + im = InstanceModel( + name=instance_name, + project=project, + pool=pool, + created_at=common_utils.get_current_datetime(), + started_at=common_utils.get_current_datetime(), + status=InstanceStatus.STARTING, + backend=backend.TYPE, + region=instance_offer.region, + price=instance_offer.price, + # job_id: Optional[FK] (current job) + # ip address + # ssh creds: user, port, dockerized + # real resources + spot (exact) / instance offer + # backend + backend data + # region + # price (for querying) + # termination policy + # creation policy + job_provisioning_data=job_provisioning_data.json(), + # TODO: instance provisioning + offer=cast(InstanceOfferWithAvailability, instance_offer).json(), + termination_policy=profile.termination_policy, + termination_idle_time=profile.termination_idle_time, + ) + session.add(im) + await session.commit() + + return instance_model_to_instance(im) + + +async def get_run_plan( + session: AsyncSession, project: ProjectModel, user: UserModel, run_spec: RunSpec ) -> RunPlan: + profile = run_spec.profile + + # TODO: get_or_create_default_pool + + pool_name = profile.pool_name + if profile.pool_name is None: + try: + pool_name = project.default_pool.name + except Exception: + pool_name = DEFAULT_POOL_NAME # TODO: get pool from project + + pool_instances = [ + instance + for instance in (await get_pool_instances(session, project, pool_name)) + if not instance.deleted + ] + + pool_offers: List[InstanceOfferWithAvailability] = [] + + for instance in filter_pool_instances( + pool_instances, profile, run_spec.configuration.resources + ): + offer = pydantic.parse_raw_as(InstanceOfferWithAvailability, instance.offer) + if instance.status == InstanceStatus.READY: + offer.availability = InstanceAvailability.READY + else: + offer.availability = InstanceAvailability.BUSY + pool_offers.append(offer) + backends = await backends_services.get_project_backends(project=project) - if run_spec.profile.backends is not None: - 
backends = [b for b in backends if b.TYPE in run_spec.profile.backends] + if profile.backends is not None: + backends = [b for b in backends if b.TYPE in profile.backends] + run_name = run_spec.run_name # preserve run_name run_spec.run_name = "dry-run" # will regenerate jobs on submission jobs = get_jobs_from_run_spec(run_spec) job_plans = [] + + creation_policy = profile.creation_policy + for job in jobs: - offers = await backends_services.get_instance_offers( - backends=backends, - job=job, - exclude_not_available=False, - ) - for backend, offer in offers: - offer.backend = backend.TYPE - offers = [offer for _, offer in offers] + job_offers: List[InstanceOfferWithAvailability] = [] + job_offers.extend(pool_offers) + + if creation_policy is None or creation_policy == CreationPolicy.REUSE_OR_CREATE: + requirements = job.job_spec.requirements + offers = await backends_services.get_instance_offers( + backends=backends, + requirements=requirements, + exclude_not_available=False, + ) + for backend, offer in offers: + offer.backend = backend.TYPE + job_offers.extend(offer for _, offer in offers) + + # TODO(egor-s): merge job_offers and pool_offers based on (availability, use/create, price) job_plan = JobPlan( job_spec=job.job_spec, - offers=offers[:50], - total_offers=len(offers), - max_price=max((offer.price for offer in offers), default=None), + offers=job_offers[:50], + total_offers=len(job_offers), + max_price=max((offer.price for offer in job_offers), default=None), ) job_plans.append(job_plan) + + run_spec.profile.pool_name = pool_name # write pool name back for the client run_spec.run_name = run_name # restore run_name run_plan = RunPlan( project_name=project.name, user=user.name, run_spec=run_spec, job_plans=job_plans @@ -167,9 +375,11 @@ async def submit_run( ) if repo is None: raise RepoDoesNotExistError.with_id(run_spec.repo_id) + backends = await backends_services.get_project_backends(project) if len(backends) == 0: raise ServerClientError("No backends 
configured") + if run_spec.run_name is None: run_spec.run_name = await _generate_run_name( session=session, @@ -177,6 +387,9 @@ async def submit_run( ) else: await delete_runs(session=session, project=project, runs_names=[run_spec.run_name]) + + pool = await get_or_create_default_pool_by_name(session, project, run_spec.profile.pool_name) + run_model = RunModel( id=uuid.uuid4(), project_id=project.id, @@ -188,10 +401,13 @@ async def submit_run( run_spec=run_spec.json(), ) session.add(run_model) + jobs = get_jobs_from_run_spec(run_spec) if run_spec.configuration.type == "service": await gateways.register_service_jobs(session, project, run_spec.run_name, jobs) + for job in jobs: + job.job_spec.pool_name = pool.name job_model = create_job_model_for_new_submission( run_model=run_model, job=job, @@ -200,6 +416,7 @@ async def submit_run( session.add(job_model) await session.commit() await session.refresh(run_model) + run = run_model_to_run(run_model) return run @@ -238,11 +455,13 @@ async def stop_runs( new_status = JobStatus.ABORTED res = await session.execute( - select(JobModel).where( + select(JobModel) + .where( JobModel.project_id == project.id, JobModel.run_name.in_(runs_names), JobModel.status.not_in(JobStatus.finished_statuses()), ) + .options(joinedload(JobModel.instance)) ) job_models = res.scalars().all() for job_model in job_models: @@ -296,10 +515,13 @@ def run_model_to_run(run_model: RunModel, include_job_submissions: bool = True) submissions.append(job_model_to_job_submission(job_model)) if job_spec is not None: jobs.append(Job(job_spec=job_spec, job_submissions=submissions)) + run_spec = RunSpec.parse_raw(run_model.run_spec) + latest_job_submission = None if include_job_submissions: latest_job_submission = jobs[0].job_submissions[-1] + run = Run( id=run_model.id, project_name=run_model.project.name, @@ -346,7 +568,7 @@ async def _generate_run_name( def _get_run_cost(run: Run) -> float: - run_cost = sum( + run_cost = math.fsum( 
_get_job_submission_cost(submission) for job in run.jobs for submission in job.job_submissions @@ -387,3 +609,18 @@ def _get_run_service(run: Run) -> Optional[ServiceInfo]: ), model=model, ) + + +async def abort_runs_of_pool(session: AsyncSession, project_model: ProjectModel, pool_name: str): + runs = await list_project_runs(session, project_model, repo_id=None) + active_run_names = [] + for run_model in runs: + if run_model.status.is_finished(): + continue + + run = run_model_to_run(run_model) + run_pool_name = run.run_spec.profile.pool_name + if run_pool_name == pool_name: + active_run_names.append(run.run_spec.run_name) + + await stop_runs(session, project_model, active_run_names, abort=True) diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 1db28b8ac..ef376d34a 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -9,16 +9,29 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import DevEnvironmentConfiguration from dstack._internal.core.models.instances import InstanceType, Resources -from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.models.profiles import ( + DEFAULT_POOL_NAME, + DEFAULT_TERMINATION_IDLE_TIME, + Profile, +) from dstack._internal.core.models.repos.base import RepoType from dstack._internal.core.models.repos.local import LocalRunRepoData -from dstack._internal.core.models.runs import JobErrorCode, JobProvisioningData, JobStatus, RunSpec +from dstack._internal.core.models.resources import ResourcesSpec +from dstack._internal.core.models.runs import ( + InstanceStatus, + JobErrorCode, + JobProvisioningData, + JobStatus, + RunSpec, +) from dstack._internal.core.models.users import GlobalRole from dstack._internal.server.models import ( BackendModel, GatewayComputeModel, GatewayModel, + InstanceModel, JobModel, + PoolModel, 
ProjectModel, RepoModel, RunModel, @@ -136,8 +149,10 @@ async def create_repo( def get_run_spec( run_name: str, repo_id: str, - profile: Optional[Profile] = Profile(name="default"), + profile: Optional[Profile] = None, ) -> RunSpec: + if profile is None: + profile = Profile(name="default") return RunSpec( run_name=run_name, repo_id=repo_id, @@ -189,6 +204,7 @@ async def create_job( last_processed_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), error_code: Optional[JobErrorCode] = None, job_provisioning_data: Optional[JobProvisioningData] = None, + instance: Optional[InstanceModel] = None, ) -> JobModel: run_spec = RunSpec.parse_raw(run.run_spec) job_spec = get_job_specs_from_run_spec(run_spec)[0] @@ -205,6 +221,7 @@ async def create_job( error_code=error_code, job_spec_data=job_spec.json(), job_provisioning_data=job_provisioning_data.json() if job_provisioning_data else None, + instance=instance if instance is not None else None, ) session.add(job) await session.commit() @@ -226,6 +243,7 @@ def get_job_provisioning_data() -> JobProvisioningData: ssh_port=22, dockerized=False, backend_data=None, + ssh_proxy=None, ) @@ -271,3 +289,47 @@ async def create_gateway_compute( session.add(gateway_compute) await session.commit() return gateway_compute + + +async def create_pool( + session: AsyncSession, + project: ProjectModel, + pool_name: Optional[str] = None, +) -> PoolModel: + pool_name = pool_name if pool_name is not None else DEFAULT_POOL_NAME + pool = PoolModel( + name=pool_name, + project=project, + project_id=project.id, + ) + session.add(pool) + await session.commit() + return pool + + +async def create_instance( + session: AsyncSession, + project: ProjectModel, + pool: PoolModel, + status: InstanceStatus, + resources: ResourcesSpec, +) -> InstanceModel: + im = InstanceModel( + name="test_instance", + pool=pool, + project=project, + status=status, + job_provisioning_data='{"backend": "datacrunch", "instance_type": {"name": "instance", "resources": 
{"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "instance_id": "running_instance.id", "ssh_proxy": null, "hostname": "running_instance.ip", "region": "running_instance.location", "price": 0.1, "username": "root", "ssh_port": 22, "dockerized": true, "backend_data": null}', + offer='{"backend": "datacrunch", "instance": {"name": "instance", "resources": {"cpus": 2, "memory_mib": 12000, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "region": "en", "price": 0.1, "availability": "available"}', + price=1, + region="eu-west", + backend=BackendType.DATACRUNCH, + termination_idle_time=DEFAULT_TERMINATION_IDLE_TIME, + ) + session.add(im) + await session.commit() + + # pool.instances.append(im) + # await session.commit() + + return im diff --git a/src/dstack/api/_public/__init__.py b/src/dstack/api/_public/__init__.py index 1c1490619..4855fb7b9 100644 --- a/src/dstack/api/_public/__init__.py +++ b/src/dstack/api/_public/__init__.py @@ -6,6 +6,7 @@ from dstack._internal.utils.logging import get_logger from dstack._internal.utils.path import PathLike from dstack.api._public.backends import BackendCollection +from dstack.api._public.pools import PoolCollection from dstack.api._public.repos import RepoCollection, get_ssh_keypair from dstack.api._public.runs import RunCollection from dstack.api.server import APIClient @@ -40,6 +41,7 @@ def __init__( self._repos = RepoCollection(api_client, project_name) self._backends = BackendCollection(api_client, project_name) self._runs = RunCollection(api_client, project_name, self) + self._pool = PoolCollection(api_client, project_name) if ssh_identity_file: self.ssh_identity_file = str(ssh_identity_file) else: @@ -95,3 +97,7 @@ def client(self) -> APIClient: @property def project(self) -> str: return self._project + + @property + def pool(self) -> PoolCollection: + return self._pool diff --git a/src/dstack/api/_public/pools.py 
b/src/dstack/api/_public/pools.py new file mode 100644 index 000000000..a496c2f4b --- /dev/null +++ b/src/dstack/api/_public/pools.py @@ -0,0 +1,41 @@ +from typing import List + +from dstack._internal.core.models.pools import Pool +from dstack.api.server import APIClient + + +class PoolInstance: + def __init__(self, api_client: APIClient, pool: Pool): + self._api_client = api_client + self._pool = pool + + @property + def name(self) -> str: + return self._pool.name + + def __str__(self) -> str: + return f"" + + def __repr__(self) -> str: + return f"" + + +class PoolCollection: + """ + Operations with pools + """ + + def __init__(self, api_client: APIClient, project: str): + self._api_client = api_client + self._project = project + + def list(self) -> List[PoolInstance]: + """ + List available pool in the project + + Returns: + pools + """ + list_raw_pool = self._api_client.pool.list(project_name=self._project) + list_pool = [PoolInstance(self._api_client, instance) for instance in list_raw_pool] + return list_pool diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index b96e4d956..1aaf085ad 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -7,7 +7,7 @@ from copy import copy from datetime import datetime from pathlib import Path -from typing import Dict, Iterable, List, Optional, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union from websocket import WebSocketApp @@ -15,10 +15,19 @@ from dstack._internal.core.errors import ConfigurationError, ResourceNotExistsError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import AnyRunConfiguration -from dstack._internal.core.models.profiles import Profile, ProfileRetryPolicy, SpotPolicy +from dstack._internal.core.models.instances import InstanceOfferWithAvailability, SSHKey +from dstack._internal.core.models.pools import Instance +from dstack._internal.core.models.profiles 
import ( + DEFAULT_TERMINATION_IDLE_TIME, + CreationPolicy, + Profile, + ProfileRetryPolicy, + SpotPolicy, + TerminationPolicy, +) from dstack._internal.core.models.repos.base import Repo from dstack._internal.core.models.resources import ResourcesSpec -from dstack._internal.core.models.runs import JobSpec, RunPlan, RunSpec +from dstack._internal.core.models.runs import JobSpec, Requirements, RunPlan, RunSpec from dstack._internal.core.models.runs import JobStatus as RunStatus from dstack._internal.core.models.runs import Run as RunModel from dstack._internal.core.services.logs import URLReplacer @@ -265,6 +274,7 @@ def attach( if not control_sock_path_and_port_locks: self._ssh_attach.attach() self._ports_lock = None + return True def detach(self): @@ -355,6 +365,18 @@ def submit( ) return self.exec_plan(run_plan, repo, reserve_ports=reserve_ports) + def get_offers( + self, profile: Profile, requirements: Requirements + ) -> Tuple[str, List[InstanceOfferWithAvailability]]: + return self._api_client.runs.get_offers(self._project, profile, requirements) + + def create_instance( + self, pool_name: str, profile: Profile, requirements: Requirements, ssh_key: SSHKey + ) -> Instance: + return self._api_client.runs.create_instance( + self._project, pool_name, profile, requirements, ssh_key + ) + def get_plan( self, configuration: AnyRunConfiguration, @@ -368,6 +390,11 @@ def get_plan( max_price: Optional[float] = None, working_dir: Optional[str] = None, run_name: Optional[str] = None, + pool_name: Optional[str] = None, + instance_name: Optional[str] = None, + creation_policy: Optional[CreationPolicy] = None, + termination_policy: Optional[TerminationPolicy] = None, + termination_policy_idle: int = DEFAULT_TERMINATION_IDLE_TIME, ) -> RunPlan: # """ # Get run plan. Same arguments as `submit` @@ -378,10 +405,10 @@ def get_plan( if working_dir is None: working_dir = "." 
elif repo.repo_dir is not None: - working_dir = Path(repo.repo_dir) / working_dir - if not path_in_dir(working_dir, repo.repo_dir): + working_dir_path = Path(repo.repo_dir) / working_dir + if not path_in_dir(working_dir_path, repo.repo_dir): raise ConfigurationError("Working directory is outside of the repo") - working_dir = working_dir.relative_to(repo.repo_dir).as_posix() + working_dir = working_dir_path.relative_to(repo.repo_dir).as_posix() if configuration_path is None: configuration_path = "(python)" @@ -397,6 +424,11 @@ def get_plan( retry_policy=retry_policy, max_duration=max_duration, max_price=max_price, + pool_name=pool_name, + instance_name=instance_name, + creation_policy=creation_policy, + termination_policy=termination_policy, + termination_idle_time=termination_policy_idle, ) run_spec = RunSpec( run_name=run_name, diff --git a/src/dstack/api/server/__init__.py b/src/dstack/api/server/__init__.py index c5c9749cf..955bf956e 100644 --- a/src/dstack/api/server/__init__.py +++ b/src/dstack/api/server/__init__.py @@ -10,6 +10,7 @@ from dstack.api.server._backends import BackendsAPIClient from dstack.api.server._gateways import GatewaysAPIClient from dstack.api.server._logs import LogsAPIClient +from dstack.api.server._pools import PoolAPIClient from dstack.api.server._projects import ProjectsAPIClient from dstack.api.server._repos import ReposAPIClient from dstack.api.server._runs import RunsAPIClient @@ -34,6 +35,7 @@ class APIClient: runs: operations with runs logs: operations with logs gateways: operations with gateways + pools: operations with pools """ def __init__(self, base_url: str, token: str): @@ -82,8 +84,16 @@ def secrets(self) -> SecretsAPIClient: def gateways(self) -> GatewaysAPIClient: return GatewaysAPIClient(self._request) + @property + def pool(self) -> PoolAPIClient: + return PoolAPIClient(self._request) + def _request( - self, path: str, body: Optional[str] = None, raise_for_status: bool = True, **kwargs + self, + path: str, + body: 
Optional[str] = None, + raise_for_status: bool = True, + **kwargs, ) -> requests.Response: path = path.lstrip("/") if body is not None: diff --git a/src/dstack/api/server/_backends.py b/src/dstack/api/server/_backends.py index 85de8e7d9..2d2136abb 100644 --- a/src/dstack/api/server/_backends.py +++ b/src/dstack/api/server/_backends.py @@ -21,11 +21,15 @@ def config_values(self, config: AnyConfigInfoWithCredsPartial) -> AnyConfigValue resp = self._request("/api/backends/config_values", body=config.json()) return parse_obj_as(AnyConfigValues, resp.json()) - def create(self, project_name: str, config: AnyConfigInfoWithCreds) -> AnyConfigInfoWithCreds: + def create( + self, project_name: str, config: AnyConfigInfoWithCreds + ) -> AnyConfigInfoWithCredsPartial: resp = self._request(f"/api/project/{project_name}/backends/create", body=config.json()) return parse_obj_as(AnyConfigInfoWithCredsPartial, resp.json()) - def update(self, project_name: str, config: AnyConfigInfoWithCreds) -> AnyConfigInfoWithCreds: + def update( + self, project_name: str, config: AnyConfigInfoWithCreds + ) -> AnyConfigInfoWithCredsPartial: resp = self._request(f"/api/project/{project_name}/backends/update", body=config.json()) return parse_obj_as(AnyConfigInfoWithCredsPartial, resp.json()) diff --git a/src/dstack/api/server/_pools.py b/src/dstack/api/server/_pools.py new file mode 100644 index 000000000..fc5cdf7d2 --- /dev/null +++ b/src/dstack/api/server/_pools.py @@ -0,0 +1,61 @@ +from typing import List, Optional + +from pydantic import parse_obj_as + +import dstack._internal.server.schemas.pools as schemas_pools +from dstack._internal.core.models.pools import Pool, PoolInstances +from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.models.resources import ResourcesSpec +from dstack._internal.server.schemas.runs import AddRemoteInstanceRequest +from dstack.api.server._group import APIClientGroup + + +class PoolAPIClient(APIClientGroup): + def list(self, 
project_name: str) -> List[Pool]: + resp = self._request(f"/api/project/{project_name}/pool/list") + result: List[Pool] = parse_obj_as(List[Pool], resp.json()) + return result + + def delete(self, project_name: str, pool_name: str, force: bool) -> None: + body = schemas_pools.DeletePoolRequest(name=pool_name, force=force) + self._request(f"/api/project/{project_name}/pool/delete", body=body.json()) + + def create(self, project_name: str, pool_name: str) -> None: + body = schemas_pools.CreatePoolRequest(name=pool_name) + self._request(f"/api/project/{project_name}/pool/create", body=body.json()) + + def show(self, project_name: str, pool_name: Optional[str]) -> PoolInstances: + body = schemas_pools.ShowPoolRequest(name=pool_name) + resp = self._request(f"/api/project/{project_name}/pool/show", body=body.json()) + pool: PoolInstances = parse_obj_as(PoolInstances, resp.json()) + return pool + + def remove(self, project_name: str, pool_name: str, instance_name: str, force: bool) -> None: + body = schemas_pools.RemoveInstanceRequest( + pool_name=pool_name, instance_name=instance_name, force=force + ) + self._request(f"/api/project/{project_name}/pool/remove", body=body.json()) + + def set_default(self, project_name: str, pool_name: str) -> bool: + body = schemas_pools.SetDefaultPoolRequest(pool_name=pool_name) + result = self._request(f"/api/project/{project_name}/pool/set_default", body=body.json()) + return bool(result.json()) + + def add_remote( + self, + project_name: str, + resources: ResourcesSpec, + profile: Profile, + instance_name: Optional[str], + host: str, + port: str, + ) -> bool: + body = AddRemoteInstanceRequest( + profile=profile, + instance_name=instance_name, + host=host, + port=port, + resources=resources, + ) + result = self._request(f"/api/project/{project_name}/pool/add_remote", body=body.json()) + return bool(result.json()) diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py index 63bc79367..51ee77115 100644 --- 
a/src/dstack/api/server/_runs.py +++ b/src/dstack/api/server/_runs.py @@ -1,14 +1,20 @@ -from typing import List, Optional +from typing import List, Optional, Tuple from pydantic import parse_obj_as -from dstack._internal.core.models.runs import Run, RunPlan, RunSpec +from dstack._internal.core.models.instances import InstanceOfferWithAvailability, SSHKey +from dstack._internal.core.models.pools import Instance +from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.models.runs import Requirements, Run, RunPlan, RunSpec from dstack._internal.server.schemas.runs import ( + CreateInstanceRequest, DeleteRunsRequest, + GetOffersRequest, GetRunPlanRequest, GetRunRequest, ListRunsRequest, StopRunsRequest, + SubmitRunRequest, ) from dstack.api.server._group import APIClientGroup @@ -24,13 +30,34 @@ def get(self, project_name: str, run_name: str) -> Run: resp = self._request(f"/api/project/{project_name}/runs/get", body=body.json()) return parse_obj_as(Run, resp.json()) + def get_offers( + self, project_name: str, profile: Profile, requirements: Requirements + ) -> Tuple[str, List[InstanceOfferWithAvailability]]: + body = GetOffersRequest(profile=profile, requirements=requirements) + resp = self._request(f"/api/project/{project_name}/runs/get_offers", body=body.json()) + return parse_obj_as(Tuple[str, List[InstanceOfferWithAvailability]], resp.json()) + + def create_instance( + self, + project_name: str, + pool_name: str, + profile: Profile, + requirements: Requirements, + ssh_key: SSHKey, + ) -> Instance: + body = CreateInstanceRequest( + pool_name=pool_name, profile=profile, requirements=requirements, ssh_key=ssh_key + ) + resp = self._request(f"/api/project/{project_name}/runs/create_instance", body=body.json()) + return parse_obj_as(Instance, resp.json()) + def get_plan(self, project_name: str, run_spec: RunSpec) -> RunPlan: body = GetRunPlanRequest(run_spec=run_spec) resp = self._request(f"/api/project/{project_name}/runs/get_plan", 
body=body.json()) return parse_obj_as(RunPlan, resp.json()) def submit(self, project_name: str, run_spec: RunSpec) -> Run: - body = GetRunPlanRequest(run_spec=run_spec) + body = SubmitRunRequest(run_spec=run_spec) resp = self._request(f"/api/project/{project_name}/runs/submit", body=body.json()) return parse_obj_as(Run, resp.json()) diff --git a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py index c539d3c6a..d645519e7 100644 --- a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py @@ -8,12 +8,14 @@ from dstack._internal.core.errors import SSHError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import InstanceType, Resources -from dstack._internal.core.models.runs import JobProvisioningData, JobStatus +from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData, JobStatus from dstack._internal.server import settings from dstack._internal.server.background.tasks.process_running_jobs import process_running_jobs from dstack._internal.server.schemas.runner import HealthcheckResponse, JobStateEvent, PullResponse from dstack._internal.server.testing.common import ( + create_instance, create_job, + create_pool, create_project, create_repo, create_run, @@ -36,6 +38,7 @@ def get_job_provisioning_data(dockerized: bool) -> JobProvisioningData: ssh_port=22, dockerized=dockerized, backend_data=None, + ssh_proxy=None, ) @@ -110,7 +113,7 @@ async def test_runs_provisioning_job(self, test_db, session: AsyncSession): ) as RunnerClientMock: runner_client_mock = RunnerClientMock.return_value runner_client_mock.healthcheck.return_value = HealthcheckResponse( - service="dstack-runner" + service="dstack-runner", version="0.0.1.dev2" ) await process_running_jobs() RunnerTunnelMock.assert_called_once() @@ 
-195,24 +198,31 @@ async def test_provisioning_shim(self, test_db, session: AsyncSession): user=user, ) job_provisioning_data = get_job_provisioning_data(dockerized=True) - job = await create_job( - session=session, - run=run, - status=JobStatus.PROVISIONING, - job_provisioning_data=job_provisioning_data, - ) + + with patch( + "dstack._internal.server.services.jobs.configurators.base.get_default_python_verison" + ) as PyVersion: + PyVersion.return_value = "3.11" + job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + job_provisioning_data=job_provisioning_data, + ) with patch( "dstack._internal.server.services.runner.ssh.RunnerTunnel" ) as RunnerTunnelMock, patch( "dstack._internal.server.services.runner.client.ShimClient" ) as ShimClientMock: ShimClientMock.return_value.healthcheck.return_value = HealthcheckResponse( - service="dstack-shim" + service="dstack-shim", version="0.0.1.dev2" ) await process_running_jobs() RunnerTunnelMock.assert_called_once() ShimClientMock.return_value.healthcheck.assert_called_once() - ShimClientMock.return_value.registry_auth.assert_not_called() + ShimClientMock.return_value.submit.assert_called_once_with( + username="", password="", image_name="dstackai/base:py3.11-0.4rc4-cuda-12.1" + ) await session.refresh(job) assert job is not None assert job.status == JobStatus.PULLING @@ -246,7 +256,7 @@ async def test_pulling_shim(self, test_db, session: AsyncSession): "dstack._internal.server.services.runner.client.ShimClient" ) as ShimClientMock: RunnerTunnelMock.return_value.healthcheck.return_value = HealthcheckResponse( - service="dstack-runner" + service="dstack-runner", version="0.0.1.dev2" ) await process_running_jobs() RunnerTunnelMock.assert_called_once() @@ -274,12 +284,20 @@ async def test_pulling_shim_failed(self, test_db, session: AsyncSession): repo=repo, user=user, ) + instance = await create_instance( + session, + project, + await create_pool(session, project), + InstanceStatus.READY, + 
Resources(cpus=1, memory_mib=512, spot=False, gpus=[]), + ) job_provisioning_data = get_job_provisioning_data(dockerized=True) job = await create_job( session=session, run=run, status=JobStatus.PULLING, job_provisioning_data=job_provisioning_data, + instance=instance, ) with patch( "dstack._internal.server.services.runner.ssh.RunnerTunnel" diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py index bb0e175ee..498de90c6 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py @@ -2,6 +2,7 @@ from unittest.mock import Mock, patch import pytest +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal.core.models.backends.base import BackendType @@ -12,10 +13,15 @@ LaunchedInstanceInfo, Resources, ) -from dstack._internal.core.models.profiles import Profile, ProfileRetryPolicy -from dstack._internal.core.models.runs import JobStatus +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile, ProfileRetryPolicy +from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData, JobStatus from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs +from dstack._internal.server.models import JobModel +from dstack._internal.server.services.pools import ( + get_or_create_default_pool_by_name, +) from dstack._internal.server.testing.common import ( + create_instance, create_job, create_project, create_repo, @@ -23,6 +29,7 @@ create_user, get_run_spec, ) +from dstack.api._public.resources import Resources as MakeResources class TestProcessSubmittedJobs: @@ -67,22 +74,21 @@ async def test_provisiones_job(self, test_db, session: AsyncSession): session=session, run=run, ) + offer = InstanceOfferWithAvailability( + backend=BackendType.AWS, + 
instance=InstanceType( + name="instance", + resources=Resources(cpus=1, memory_mib=512, spot=False, gpus=[]), + ), + region="us", + price=1.0, + availability=InstanceAvailability.AVAILABLE, + ) with patch("dstack._internal.server.services.backends.get_project_backends") as m: backend_mock = Mock() m.return_value = [backend_mock] backend_mock.TYPE = BackendType.AWS - backend_mock.compute.return_value.get_offers.return_value = [ - InstanceOfferWithAvailability( - backend=BackendType.AWS, - instance=InstanceType( - name="instance", - resources=Resources(cpus=1, memory_mib=512, spot=False, gpus=[]), - ), - region="us", - price=1.0, - availability=InstanceAvailability.AVAILABLE, - ) - ] + backend_mock.compute.return_value.get_offers.return_value = [offer] backend_mock.compute.return_value.run_job.return_value = LaunchedInstanceInfo( instance_id="instance_id", region="us", @@ -95,10 +101,22 @@ async def test_provisiones_job(self, test_db, session: AsyncSession): m.assert_called_once() backend_mock.compute.return_value.get_offers.assert_called_once() backend_mock.compute.return_value.run_job.assert_called_once() + await session.refresh(job) assert job is not None assert job.status == JobStatus.PROVISIONING + await session.refresh(project) + assert project.default_pool.name == DEFAULT_POOL_NAME + + instance_offer = InstanceOfferWithAvailability.parse_raw( + project.default_pool.instances[0].offer + ) + assert offer == instance_offer + + pool_job_provisioning_data = project.default_pool.instances[0].job_provisioning_data + assert pool_job_provisioning_data == job.job_provisioning_data + @pytest.mark.asyncio async def test_transitions_job_with_retry_to_pending_on_no_capacity( self, test_db, session: AsyncSession @@ -132,10 +150,14 @@ async def test_transitions_job_with_retry_to_pending_on_no_capacity( with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock: datetime_mock.return_value = datetime(2023, 1, 2, 3, 30, 0, tzinfo=timezone.utc) await 
process_submitted_jobs() + await session.refresh(job) assert job is not None assert job.status == JobStatus.PENDING + await session.refresh(project) + assert not project.default_pool.instances + @pytest.mark.asyncio async def test_transitions_job_with_outdated_retry_to_failed_on_no_capacity( self, test_db, session: AsyncSession @@ -169,6 +191,80 @@ async def test_transitions_job_with_outdated_retry_to_failed_on_no_capacity( with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock: datetime_mock.return_value = datetime(2023, 1, 2, 5, 0, 0, tzinfo=timezone.utc) await process_submitted_jobs() + await session.refresh(job) assert job is not None assert job.status == JobStatus.FAILED + + await session.refresh(project) + assert not project.default_pool.instances + + @pytest.mark.asyncio + async def test_job_with_instance(self, test_db, session: AsyncSession): + project = await create_project(session) + user = await create_user(session) + repo = await create_repo( + session, + project_id=project.id, + ) + pool = await get_or_create_default_pool_by_name(session, project, pool_name=None) + resources = MakeResources(cpu=2, memory="12GB") + await create_instance(session, project, pool, InstanceStatus.READY, resources) + await session.refresh(pool) + run = await create_run( + session, + project=project, + repo=repo, + user=user, + ) + job_provisioning_data = JobProvisioningData( + backend=BackendType.LOCAL, + instance_type=InstanceType( + name="local", + resources=Resources(cpus=2, memory_mib=12 * 1024, gpus=[], spot=False), + ), + instance_id="0000-0000", + hostname="localhost", + region="", + price=0.0, + username="root", + ssh_port=22, + dockerized=False, + backend_data=None, + ssh_proxy=None, + ) + with patch( + "dstack._internal.server.services.jobs.configurators.base.get_default_python_verison" + ) as PyVersion: + PyVersion.return_value = "3.10" + job = await create_job( + session, + run=run, + job_provisioning_data=job_provisioning_data, + ) + 
await process_submitted_jobs() + await session.refresh(job) + assert job is not None + assert job.status == JobStatus.PROVISIONING + + res = await session.execute(select(JobModel).where()) + jm = res.all()[0][0] + assert jm.job_num == 0 + assert jm.run_name == "test-run" + assert jm.job_name == "test-run-0" + assert jm.submission_num == 0 + assert jm.status == JobStatus.PROVISIONING + assert jm.error_code is None + assert ( + jm.job_spec_data + == r"""{"job_num": 0, "job_name": "test-run-0", "app_specs": [], "commands": ["/bin/bash", "-i", "-c", "(echo pip install ipykernel... && pip install -q --no-cache-dir ipykernel 2> /dev/null) || echo \"no pip, ipykernel was not installed\" && echo '' && echo To open in VS Code Desktop, use link below: && echo '' && echo ' vscode://vscode-remote/ssh-remote+test-run/workflow' && echo '' && echo 'To connect via SSH, use: `ssh test-run`' && echo '' && echo -n 'To exit, press Ctrl+C.' && tail -f /dev/null"], "env": {}, "gateway": null, "home_dir": "/root", "image_name": "dstackai/base:py3.10-0.4rc4-cuda-12.1", "max_duration": 21600, "registry_auth": null, "requirements": {"resources": {"cpu": {"min": 2, "max": null}, "memory": {"min": 8.0, "max": null}, "shm_size": null, "gpu": null, "disk": null}, "max_price": null, "spot": false}, "retry_policy": {"retry": false, "limit": null}, "working_dir": ".", "pool_name": null}""" + ) + assert jm.job_provisioning_data == ( + '{"backend": "datacrunch", "instance_type": {"name": "instance", "resources": ' + '{"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": ' + '{"size_mib": 102400}, "description": ""}}, "instance_id": ' + '"running_instance.id", "ssh_proxy": null, ' + '"hostname": "running_instance.ip", "region": "running_instance.location", ' + '"price": 0.1, "username": "root", "ssh_port": 22, "dockerized": true, ' + '"backend_data": null}' + ) diff --git a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py 
b/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py index bb4173eef..89e96beed 100644 --- a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py @@ -5,17 +5,19 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import InstanceType, Resources -from dstack._internal.core.models.runs import JobProvisioningData, JobStatus +from dstack._internal.core.models.runs import InstanceStatus, JobProvisioningData, JobStatus from dstack._internal.server.background.tasks.process_finished_jobs import process_finished_jobs from dstack._internal.server.testing.common import ( + create_instance, create_job, + create_pool, create_project, create_repo, create_run, create_user, ) -MODULE = "dstack._internal.server.background.tasks.process_finished_jobs" +MODULE = "dstack._internal.server.services.jobs" class TestProcessFinishedJobs: @@ -33,12 +35,21 @@ async def test_transitions_done_jobs_marked_as_removed(self, test_db, session: A repo=repo, user=user, ) + instance = await create_instance( + session, + project, + await create_pool(session, project), + InstanceStatus.READY, + Resources(cpus=1, memory_mib=512, spot=False, gpus=[]), + ) + job = await create_job( session=session, run=run, status=JobStatus.DONE, + instance=instance, job_provisioning_data=JobProvisioningData( - backend=BackendType.LOCAL, + backend=BackendType.AWS, instance_type=InstanceType( name="local", resources=Resources(cpus=1, memory_mib=1024, gpus=[], spot=False) ), @@ -49,11 +60,12 @@ async def test_transitions_done_jobs_marked_as_removed(self, test_db, session: A username="root", ssh_port=22, dockerized=False, + backend_data=None, + ssh_proxy=None, ), ) - with patch(f"{MODULE}.terminate_job_submission_instance") as terminate: + with patch(f"{MODULE}.terminate_job_provisioning_data_instance"): await process_finished_jobs() - 
terminate.assert_called_once() await session.refresh(job) assert job is not None assert job.removed diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 62122417f..30abbb495 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -16,6 +16,7 @@ InstanceType, Resources, ) +from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME from dstack._internal.core.models.runs import JobSpec, JobStatus, RunSpec from dstack._internal.core.models.users import GlobalRole, ProjectRole from dstack._internal.server.main import app @@ -70,12 +71,17 @@ def get_dev_env_run_plan_dict( "configuration_path": "dstack.yaml", "profile": { "backends": ["local", "aws", "azure", "gcp", "lambda"], + "creation_policy": None, "default": False, + "instance_name": None, "max_duration": "off", "max_price": None, "name": "string", + "pool_name": DEFAULT_POOL_NAME, "retry_policy": {"limit": None, "retry": False}, "spot_policy": "spot", + "termination_idle_time": 300, + "termination_policy": None, }, "repo_code_hash": None, "repo_data": {"repo_dir": "/repo", "repo_type": "local"}, @@ -112,6 +118,7 @@ def get_dev_env_run_plan_dict( "job_name": f"{run_name}-0", "job_num": 0, "max_duration": None, + "pool_name": DEFAULT_POOL_NAME, "registry_auth": None, "requirements": { "resources": { @@ -176,12 +183,17 @@ def get_dev_env_run_dict( "configuration_path": "dstack.yaml", "profile": { "backends": ["local", "aws", "azure", "gcp", "lambda"], + "creation_policy": None, "default": False, + "instance_name": None, "max_duration": "off", "max_price": None, "name": "string", + "pool_name": DEFAULT_POOL_NAME, "retry_policy": {"limit": None, "retry": False}, "spot_policy": "spot", + "termination_idle_time": 300, + "termination_policy": None, }, "repo_code_hash": None, "repo_data": {"repo_dir": "/repo", "repo_type": "local"}, @@ -218,6 +230,7 @@ def get_dev_env_run_dict( "job_name": 
f"{run_name}-0", "job_num": 0, "max_duration": None, + "pool_name": DEFAULT_POOL_NAME, "registry_auth": None, "requirements": { "resources": { @@ -255,7 +268,7 @@ def get_dev_env_run_dict( "error_code": None, "job_provisioning_data": None, }, - "cost": 0, + "cost": 0.0, "service": None, } @@ -576,7 +589,6 @@ async def test_terminates_running_run(self, test_db, session: AsyncSession): await session.refresh(job) assert job.status == JobStatus.TERMINATED assert not job.removed - assert job.remove_at is not None @pytest.mark.asyncio async def test_leaves_finished_runs_unchanged(self, test_db, session: AsyncSession): diff --git a/src/tests/_internal/server/services/test_pools.py b/src/tests/_internal/server/services/test_pools.py new file mode 100644 index 000000000..b47ce8e11 --- /dev/null +++ b/src/tests/_internal/server/services/test_pools.py @@ -0,0 +1,271 @@ +import datetime as dt +import uuid +from unittest.mock import patch + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +import dstack._internal.server.services.pools as services_pools +import dstack._internal.server.services.projects as services_projects +import dstack._internal.server.services.runs as runs +import dstack._internal.server.services.users as services_users +from dstack._internal.core.models import resources +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import ( + InstanceAvailability, + InstanceOfferWithAvailability, + InstanceType, + LaunchedInstanceInfo, + Resources, + SSHKey, +) +from dstack._internal.core.models.pools import Instance, Pool +from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.models.runs import InstanceStatus, Requirements +from dstack._internal.core.models.users import GlobalRole +from dstack._internal.server.models import InstanceModel +from dstack._internal.server.testing.common import create_project, create_user + + +@pytest.mark.asyncio +async def 
test_pool(session: AsyncSession, test_db): + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + pool = await services_pools.create_pool_model( + session=session, project=project, name="test_pool" + ) + im = InstanceModel( + name="test_instance", + project=project, + pool=pool, + status=InstanceStatus.PENDING, + job_provisioning_data="", + offer="", + region="", + price=1, + backend=BackendType.LOCAL, + ) + session.add(im) + await session.commit() + await session.refresh(pool) + + core_model_pool = services_pools.pool_model_to_pool(pool) + assert core_model_pool == Pool( + name="test_pool", + default=True, + created_at=pool.created_at.replace(tzinfo=dt.timezone.utc), # created_at is stored naive; normalize to UTC for comparison + total_instances=1, + available_instances=0, + ) + + list_pools = await services_pools.list_project_pool(session=session, project=project) + assert list_pools == [services_pools.pool_model_to_pool(pool)] + + list_pool_models = await services_pools.list_project_pool_models( + session=session, project=project + ) + assert len(list_pool_models) == 1 + + pool_instances = await services_pools.get_pool_instances(session, project, "test_pool") + assert pool_instances == [im] + + +def test_convert_instance(): + expected_instance = Instance( + backend=BackendType.LOCAL, + instance_type=InstanceType( + name="instance", resources=Resources(cpus=1, memory_mib=512, spot=False, gpus=[]) + ), + name="test_instance", + hostname="hostname_test", + status=InstanceStatus.PENDING, + price=1.0, + ) + + im = InstanceModel( + id=str(uuid.uuid4()), + created_at=dt.datetime.now(), + name="test_instance", + status=InstanceStatus.PENDING, + project_id=str(uuid.uuid4()), + pool=None, + job_provisioning_data='{"ssh_proxy":null, "backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": 
[], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', + offer='{"price":"LOCAL", "price":1.0, "backend":"local", "region":"eu-west-1", "availability":"available","instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', + ) + + instance = services_pools.instance_model_to_instance(im) + assert instance == expected_instance + + +@pytest.mark.asyncio +async def test_delete_pool(session: AsyncSession, test_db): + POOL_NAME = "test_pool" + user = await services_users.create_user(session, "test_user", global_role=GlobalRole.ADMIN) + project = await services_projects.create_project(session, user, "test_project") + project_model = await services_projects.get_project_model_by_name_or_error( + session, project.project_name + ) + pool = await services_pools.create_pool_model(session, project_model, POOL_NAME) + + await services_pools.delete_pool(session, project_model, POOL_NAME) + + deleted_pools = await services_pools.list_deleted_pools(session, project_model) + assert len(deleted_pools) == 1 + assert pool.name == deleted_pools[0].name + + +@pytest.mark.asyncio +async def test_show_pool(session: AsyncSession, test_db): + POOL_NAME = "test_pool" + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + pool = await services_pools.create_pool_model(session=session, project=project, name=POOL_NAME) + im = InstanceModel( + name="test_instnce", + project=project, + pool=pool, + status=InstanceStatus.PENDING, + job_provisioning_data='{"ssh_proxy":null, "backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', + offer='{"price":"LOCAL", "price":1.0, 
"backend":"local", "region":"eu-west-1", "availability":"available","instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', + region="eu-west", + price=1, + backend=BackendType.LOCAL, + ) + session.add(im) + await session.commit() + + pool_instances = await services_pools.show_pool(session, project, POOL_NAME) + assert len(pool_instances.instances) == 1 + + +@pytest.mark.asyncio +async def test_get_pool_instances(session: AsyncSession, test_db): + POOL_NAME = "test_pool" + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + pool = await services_pools.create_pool_model(session=session, project=project, name=POOL_NAME) + im = InstanceModel( + name="test_instnce", + project=project, + pool=pool, + status=InstanceStatus.PENDING, + job_provisioning_data='{"backend":"local","hostname":"hostname_test","region":"eu-west","price":1.0,"username":"user1","ssh_port":12345,"dockerized":false,"instance_id":"test_instance","instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', + offer='{"price":"LOCAL", "price":1.0, "backend":"local", "region":"eu-west-1", "availability":"available","instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description":""}}}', + region="eu-west", + price=1, + backend=BackendType.LOCAL, + ) + session.add(im) + await session.commit() + + instances = await services_pools.get_pool_instances(session, project, POOL_NAME) + assert len(instances) == 1 + + empty_instances = await services_pools.get_pool_instances(session, project, f"{POOL_NAME}-0") + assert len(empty_instances) == 0 + + +@pytest.mark.asyncio +async def test_generate_instance_name(session: AsyncSession, test_db): + user = await create_user(session=session) + 
project = await create_project(session=session, owner=user) + pool = await services_pools.create_pool_model( + session=session, project=project, name="test_pool" + ) + im = InstanceModel( + name="test_instance", + project=project, + pool=pool, + status=InstanceStatus.PENDING, + job_provisioning_data="", + offer="", + backend=BackendType.REMOTE, + region="", + price=0, + ) + session.add(im) + await session.commit() + + name = await services_pools.generate_instance_name( + session=session, project=project, pool_name="test_pool" + ) + prefix, _, suffix = name.partition("-") + assert len(prefix) > 0 + assert len(suffix) > 0 + + +@pytest.mark.asyncio +async def test_pool_double_name(session: AsyncSession, test_db): + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + await services_pools.create_pool_model(session=session, project=project, name="test_pool") + with pytest.raises(ValueError): + await services_pools.create_pool_model(session=session, project=project, name="test_pool") + + +@pytest.mark.asyncio +async def test_create_cloud_instance(session: AsyncSession, test_db): + user = await create_user(session) + project = await create_project(session, user) + + profile = Profile(name="test_profile") + + requirements = Requirements(resources=resources.ResourcesSpec(cpu=1), spot=True) + + offer = InstanceOfferWithAvailability( + backend=BackendType.DATACRUNCH, + instance=InstanceType( + name="instance", resources=Resources(cpus=1, memory_mib=512, spot=False, gpus=[]) + ), + region="en", + price=0.1, + availability=InstanceAvailability.AVAILABLE, + ) + + launched_instance = LaunchedInstanceInfo( + instance_id="running_instance.id", + ip_address="running_instance.ip", + region="running_instance.location", + ssh_port=22, + username="root", + dockerized=True, + backend_data=None, + ) + + class DummyBackend: + TYPE = BackendType.DATACRUNCH + + def compute(self): + return self + + def create_instance(self, *args, **kwargs): + return 
launched_instance + + offers = [(DummyBackend(), offer)] + + with patch("dstack._internal.server.services.runs.get_run_plan_by_requirements") as reqs: + reqs.return_value = offers + await runs.create_instance( + session, + project, + user, + profile=profile, + pool_name="test_pool", + instance_name="test_instance", + requirements=requirements, + ssh_key=SSHKey(public=""), + ) + + pool = await services_pools.get_pool(session, project, "test_pool") + assert pool is not None + instance = pool.instances[0] + + assert instance.name == "test_instance" + assert not instance.deleted + assert instance.deleted_at is None + + # assert instance.job_provisioning_data == '{"backend": "datacrunch", "instance_type": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "instance_id": "running_instance.id", "ssh_proxy": null, "hostname": "running_instance.ip", "region": "running_instance.location", "price": 0.1, "username": "root", "ssh_port": 22, "dockerized": true, "backend_data": null}' + assert ( + instance.offer + == '{"backend": "datacrunch", "instance": {"name": "instance", "resources": {"cpus": 1, "memory_mib": 512, "gpus": [], "spot": false, "disk": {"size_mib": 102400}, "description": ""}}, "region": "en", "price": 0.1, "availability": "available"}' + )