From 714d376cbcef0aebb470351e425dbc820b35b3f0 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 23 Jun 2023 15:39:50 +0500 Subject: [PATCH 01/16] Implement LambdaLabs runner backend --- .../packer/provisioners/get-dstack-runner.sh | 0 .../ami/packer/provisioners/install-docker.sh | 0 .../internal/backend/{s3 => aws}/backend.go | 67 ++++---- runner/internal/backend/{s3 => aws}/client.go | 2 +- .../backend/{s3 => aws}/cloudwatch.go | 2 +- runner/internal/backend/{s3 => aws}/ec2.go | 2 +- runner/internal/backend/{s3 => aws}/secret.go | 2 +- runner/internal/backend/lambda/api_client.go | 41 +++++ runner/internal/backend/lambda/backend.go | 157 ++++++++++++++++++ runner/main.go | 3 +- 10 files changed, 234 insertions(+), 42 deletions(-) mode change 100644 => 100755 runner/ami/packer/provisioners/get-dstack-runner.sh mode change 100644 => 100755 runner/ami/packer/provisioners/install-docker.sh rename runner/internal/backend/{s3 => aws}/backend.go (84%) rename runner/internal/backend/{s3 => aws}/client.go (99%) rename runner/internal/backend/{s3 => aws}/cloudwatch.go (99%) rename runner/internal/backend/{s3 => aws}/ec2.go (99%) rename runner/internal/backend/{s3 => aws}/secret.go (99%) create mode 100644 runner/internal/backend/lambda/api_client.go create mode 100644 runner/internal/backend/lambda/backend.go diff --git a/runner/ami/packer/provisioners/get-dstack-runner.sh b/runner/ami/packer/provisioners/get-dstack-runner.sh old mode 100644 new mode 100755 diff --git a/runner/ami/packer/provisioners/install-docker.sh b/runner/ami/packer/provisioners/install-docker.sh old mode 100644 new mode 100755 diff --git a/runner/internal/backend/s3/backend.go b/runner/internal/backend/aws/backend.go similarity index 84% rename from runner/internal/backend/s3/backend.go rename to runner/internal/backend/aws/backend.go index b08b57ac0..259aac733 100644 --- a/runner/internal/backend/s3/backend.go +++ b/runner/internal/backend/aws/backend.go @@ -1,4 +1,4 @@ -package local 
+package aws import ( "context" @@ -28,7 +28,7 @@ import ( "gopkg.in/yaml.v3" ) -type S3 struct { +type AWSBackend struct { region string bucket string runnerID string @@ -62,9 +62,8 @@ func init() { }) } -func New(region, bucket string) *S3 { - - return &S3{ +func New(region, bucket string) *AWSBackend { + return &AWSBackend{ region: region, bucket: bucket, artifacts: nil, @@ -74,7 +73,7 @@ func New(region, bucket string) *S3 { } } -func (s *S3) Init(ctx context.Context, ID string) error { +func (s *AWSBackend) Init(ctx context.Context, ID string) error { log.Trace(ctx, "Initialize backend with ID runner", "runner ID", ID) if s == nil { return gerrors.New("Backend is nil") @@ -97,15 +96,10 @@ func (s *S3) Init(ctx context.Context, ID string) error { if s.state == nil { return gerrors.New("State is empty. Data not loading") } - //Update job for local runner - if s.state.Resources.Local { - s.state.Job.RequestID = fmt.Sprintf("l-%d", os.Getpid()) - s.state.Job.RunnerID = ID - } return nil } -func (s *S3) Job(ctx context.Context) *models.Job { +func (s *AWSBackend) Job(ctx context.Context) *models.Job { log.Trace(ctx, "Getting job from state") if s == nil { return new(models.Job) @@ -118,7 +112,7 @@ func (s *S3) Job(ctx context.Context) *models.Job { return s.state.Job } -func (s *S3) RefetchJob(ctx context.Context) (*models.Job, error) { +func (s *AWSBackend) RefetchJob(ctx context.Context) (*models.Job, error) { log.Trace(ctx, "Refetching job from state", "ID", s.state.Job.JobID) contents, err := s.cliS3.GetFile(ctx, s.bucket, s.state.Job.JobFilepath()) if err != nil { @@ -131,7 +125,7 @@ func (s *S3) RefetchJob(ctx context.Context) (*models.Job, error) { return s.state.Job, nil } -func (s *S3) UpdateState(ctx context.Context) error { +func (s *AWSBackend) UpdateState(ctx context.Context) error { log.Trace(ctx, "Start update state") if s == nil { return gerrors.New("Backend is nil") @@ -167,7 +161,7 @@ func (s *S3) UpdateState(ctx context.Context) error { return 
nil } -func (s *S3) CheckStop(ctx context.Context) (bool, error) { +func (s *AWSBackend) CheckStop(ctx context.Context) (bool, error) { if s == nil { return false, gerrors.New("Backend is nil") } @@ -188,21 +182,18 @@ func (s *S3) CheckStop(ctx context.Context) (bool, error) { return false, nil } -func (s *S3) IsInterrupted(ctx context.Context) (bool, error) { +func (s *AWSBackend) IsInterrupted(ctx context.Context) (bool, error) { if !s.state.Resources.Spot { return false, nil } return s.cliEC2.IsInterruptedSpot(ctx, s.state.RequestID) } -func (s *S3) Shutdown(ctx context.Context) error { +func (s *AWSBackend) Shutdown(ctx context.Context) error { log.Trace(ctx, "Start shutdown") if s == nil { return gerrors.New("Backend is nil") } - if s.state.Resources.Local { - return nil - } if s.state.Resources.Spot { log.Trace(ctx, "Instance interruptible") if err := s.cliEC2.CancelSpot(ctx, s.state.RequestID); err != nil { @@ -215,7 +206,7 @@ func (s *S3) Shutdown(ctx context.Context) error { } -func (s *S3) GetArtifact(ctx context.Context, runName, localPath, remotePath string, mount bool) artifacts.Artifacter { +func (s *AWSBackend) GetArtifact(ctx context.Context, runName, localPath, remotePath string, mount bool) artifacts.Artifacter { if s == nil { return nil } @@ -240,7 +231,7 @@ func (s *S3) GetArtifact(ctx context.Context, runName, localPath, remotePath str return art } -func (s *S3) GetCache(ctx context.Context, runName, localPath, remotePath string) artifacts.Artifacter { +func (s *AWSBackend) GetCache(ctx context.Context, runName, localPath, remotePath string) artifacts.Artifacter { rootPath := path.Join(s.GetTMPDir(ctx), consts.USER_ARTIFACTS_DIR, runName) art, err := simple.NewSimple(s.bucket, s.region, rootPath, localPath, remotePath, true) if err != nil { @@ -250,7 +241,7 @@ func (s *S3) GetCache(ctx context.Context, runName, localPath, remotePath string return art } -func (s *S3) Requirements(ctx context.Context) models.Requirements { +func (s *AWSBackend) 
Requirements(ctx context.Context) models.Requirements { if s == nil { return models.Requirements{} } @@ -262,7 +253,7 @@ func (s *S3) Requirements(ctx context.Context) models.Requirements { return s.state.Job.Requirements } -func (s *S3) MasterJob(ctx context.Context) *models.Job { +func (s *AWSBackend) MasterJob(ctx context.Context) *models.Job { if s == nil { return new(models.Job) } @@ -282,7 +273,7 @@ func (s *S3) MasterJob(ctx context.Context) *models.Job { return masterJob } -func (s *S3) CreateLogger(ctx context.Context, logGroup, logName string) io.Writer { +func (s *AWSBackend) CreateLogger(ctx context.Context, logGroup, logName string) io.Writer { if s == nil { return nil } @@ -307,7 +298,7 @@ func (s *S3) CreateLogger(ctx context.Context, logGroup, logName string) io.Writ return s.logger.Build(ctx, logGroup, logName) } -func (s *S3) ListSubDir(ctx context.Context, dir string) ([]string, error) { +func (s *AWSBackend) ListSubDir(ctx context.Context, dir string) ([]string, error) { log.Trace(ctx, "Fetching list sub dir") if s == nil { return nil, gerrors.New("Backend is nil") @@ -319,7 +310,7 @@ func (s *S3) ListSubDir(ctx context.Context, dir string) ([]string, error) { return listDir, nil } -func (s *S3) GetJobByPath(ctx context.Context, path string) (*models.Job, error) { +func (s *AWSBackend) GetJobByPath(ctx context.Context, path string) (*models.Job, error) { log.Trace(ctx, "Fetching job by path", "Path", path) if s == nil { return nil, gerrors.New("Backend is nil") @@ -336,7 +327,7 @@ func (s *S3) GetJobByPath(ctx context.Context, path string) (*models.Job, error) return job, nil } -func (s *S3) Bucket(ctx context.Context) string { +func (s *AWSBackend) Bucket(ctx context.Context) string { log.Trace(ctx, "Getting bucket") if s == nil { return "" @@ -344,7 +335,7 @@ func (s *S3) Bucket(ctx context.Context) string { return s.bucket } -func (s *S3) Secrets(ctx context.Context) (map[string]string, error) { +func (s *AWSBackend) Secrets(ctx 
context.Context) (map[string]string, error) { log.Trace(ctx, "Getting secrets") if s == nil { return nil, gerrors.New("Backend is nil") @@ -367,7 +358,7 @@ func (s *S3) Secrets(ctx context.Context) (map[string]string, error) { return s.cliSecret.fetchSecret(ctx, s.bucket, secrets) } -func (s *S3) GitCredentials(ctx context.Context) *models.GitCredentials { +func (s *AWSBackend) GitCredentials(ctx context.Context) *models.GitCredentials { log.Trace(ctx, "Getting credentials") if s == nil { log.Error(ctx, "Backend is empty") @@ -384,7 +375,7 @@ func (s *S3) GitCredentials(ctx context.Context) *models.GitCredentials { return s.cliSecret.fetchCredentials(ctx, s.bucket, s.state.Job.RepoId) } -func (s *S3) GetRepoDiff(ctx context.Context, path string) (string, error) { +func (s *AWSBackend) GetRepoDiff(ctx context.Context, path string) (string, error) { diff, err := s.cliS3.GetFile(ctx, s.bucket, path) if err != nil { return "", gerrors.Wrap(err) @@ -392,7 +383,7 @@ func (s *S3) GetRepoDiff(ctx context.Context, path string) (string, error) { return string(diff), nil } -func (s *S3) GetRepoArchive(ctx context.Context, path, dir string) error { +func (s *AWSBackend) GetRepoArchive(ctx context.Context, path, dir string) error { archive, err := os.CreateTemp("", "archive-*.tar") if err != nil { return gerrors.Wrap(err) @@ -408,17 +399,19 @@ func (s *S3) GetRepoArchive(ctx context.Context, path, dir string) error { } defer out.Body.Close() size, err := io.Copy(archive, out.Body) + if err != nil { + return gerrors.Wrap(err) + } if size != out.ContentLength { return gerrors.New("size not equal") } - if err := repo.ExtractArchive(ctx, archive.Name(), dir); err != nil { return gerrors.Wrap(err) } return nil } -func (s *S3) GetBuildDiff(ctx context.Context, key, dst string) error { +func (s *AWSBackend) GetBuildDiff(ctx context.Context, key, dst string) error { out, err := s.cliS3.cli.GetObject(ctx, &s3.GetObjectInput{ Bucket: aws.String(s.bucket), Key: aws.String(key), @@ -439,7 
+432,7 @@ func (s *S3) GetBuildDiff(ctx context.Context, key, dst string) error { return nil } -func (s *S3) PutBuildDiff(ctx context.Context, src, key string) error { +func (s *AWSBackend) PutBuildDiff(ctx context.Context, src, key string) error { file, err := os.Open(src) if err != nil { return gerrors.Wrap(err) @@ -456,10 +449,10 @@ func (s *S3) PutBuildDiff(ctx context.Context, src, key string) error { return nil } -func (s *S3) GetTMPDir(ctx context.Context) string { +func (s *AWSBackend) GetTMPDir(ctx context.Context) string { return path.Join(common.HomeDir(), consts.TMP_DIR_PATH) } -func (s *S3) GetDockerBindings(ctx context.Context) []mount.Mount { +func (s *AWSBackend) GetDockerBindings(ctx context.Context) []mount.Mount { return []mount.Mount{} } diff --git a/runner/internal/backend/s3/client.go b/runner/internal/backend/aws/client.go similarity index 99% rename from runner/internal/backend/s3/client.go rename to runner/internal/backend/aws/client.go index 0532ce8c8..86dd6361f 100644 --- a/runner/internal/backend/s3/client.go +++ b/runner/internal/backend/aws/client.go @@ -1,4 +1,4 @@ -package local +package aws import ( "bytes" diff --git a/runner/internal/backend/s3/cloudwatch.go b/runner/internal/backend/aws/cloudwatch.go similarity index 99% rename from runner/internal/backend/s3/cloudwatch.go rename to runner/internal/backend/aws/cloudwatch.go index 3e0a0077c..f78b6f7d0 100644 --- a/runner/internal/backend/s3/cloudwatch.go +++ b/runner/internal/backend/aws/cloudwatch.go @@ -1,4 +1,4 @@ -package local +package aws import ( "context" diff --git a/runner/internal/backend/s3/ec2.go b/runner/internal/backend/aws/ec2.go similarity index 99% rename from runner/internal/backend/s3/ec2.go rename to runner/internal/backend/aws/ec2.go index 25412d59d..e8d77a400 100644 --- a/runner/internal/backend/s3/ec2.go +++ b/runner/internal/backend/aws/ec2.go @@ -1,4 +1,4 @@ -package local +package aws import ( "context" diff --git a/runner/internal/backend/s3/secret.go 
b/runner/internal/backend/aws/secret.go similarity index 99% rename from runner/internal/backend/s3/secret.go rename to runner/internal/backend/aws/secret.go index 69c9975c2..f33f0c7a0 100644 --- a/runner/internal/backend/s3/secret.go +++ b/runner/internal/backend/aws/secret.go @@ -1,4 +1,4 @@ -package local +package aws import ( "context" diff --git a/runner/internal/backend/lambda/api_client.go b/runner/internal/backend/lambda/api_client.go new file mode 100644 index 000000000..d7ac6ec7e --- /dev/null +++ b/runner/internal/backend/lambda/api_client.go @@ -0,0 +1,42 @@ +package lambda + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + + "github.com/dstackai/dstack/runner/internal/gerrors" +) + +const LAMBDA_API_URL = "https://cloud.lambdalabs.com/api/v1" + +type LambdaAPIClient struct { + apiKey string +} + +func NewLambdaAPIClient(apiKey string) *LambdaAPIClient { + return &LambdaAPIClient{apiKey: apiKey} +} + +func (client *LambdaAPIClient) TerminateInstance(ctx context.Context, instancesIDs []string) error { + body, err := json.Marshal(instancesIDs) + if err != nil { + return gerrors.Wrap(err) + } + req, err := http.NewRequestWithContext(ctx, "POST", LAMBDA_API_URL+"/instance-operations/terminate", bytes.NewReader(body)) + if err != nil { + return gerrors.Wrap(err) + } + req.Header.Add("Authorization", "Bearer "+client.apiKey) + httpClient := http.Client{} + resp, err := httpClient.Do(req) + if err != nil { + return gerrors.Wrap(err) + } + defer resp.Body.Close() + if resp.StatusCode == 200 { + return nil + } + return gerrors.Newf("/instance-operations/terminate returned non-200 status code: %s", resp.Status) +} diff --git a/runner/internal/backend/lambda/backend.go b/runner/internal/backend/lambda/backend.go new file mode 100644 index 000000000..08b585688 --- /dev/null +++ b/runner/internal/backend/lambda/backend.go @@ -0,0 +1,157 @@ +package lambda + +import ( + "context" + "io" + "io/ioutil" + "os" + + "github.com/docker/docker/api/types/mount" + 
"github.com/dstackai/dstack/runner/internal/artifacts" + "github.com/dstackai/dstack/runner/internal/backend" + "github.com/dstackai/dstack/runner/internal/backend/aws" + "github.com/dstackai/dstack/runner/internal/gerrors" + "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/models" + "gopkg.in/yaml.v2" +) + +type AWSCredentials struct { + AccessKey string `yaml:"access_key"` + SecretKey string `yaml:"secret_key"` +} + +type AWSStorageConfig struct { + Region string `yaml:"region"` + Bucket string `yaml:"bucket"` + Credentials AWSCredentials `yaml:"credentials"` +} + +type LambdaConfig struct { + ApiKey string `yaml:"api_key"` + StorageConfig AWSStorageConfig `yaml:"storage_config"` +} + +type LambdaBackend struct { + storageBackend *aws.AWSBackend + apiClient *LambdaAPIClient +} + +func init() { + backend.RegisterBackend("lambda", func(ctx context.Context, pathConfig string) (backend.Backend, error) { + config := LambdaConfig{} + log.Trace(ctx, "Read config file", "path", pathConfig) + fileContent, err := ioutil.ReadFile(pathConfig) + if err != nil { + return nil, gerrors.Wrap(err) + } + log.Trace(ctx, "Unmarshal config") + err = yaml.Unmarshal(fileContent, &config) + if err != nil { + return nil, gerrors.Wrap(err) + } + return New(config), nil + }) +} + +func New(config LambdaConfig) *LambdaBackend { + os.Setenv("AWS_ACCESS_KEY_ID", config.StorageConfig.Credentials.AccessKey) + os.Setenv("AWS_SECRET_ACCESS_KEY", config.StorageConfig.Credentials.SecretKey) + return &LambdaBackend{ + storageBackend: aws.New(config.StorageConfig.Region, config.StorageConfig.Bucket), + apiClient: NewLambdaAPIClient(config.ApiKey), + } +} + +func (l *LambdaBackend) Init(ctx context.Context, ID string) error { + return l.storageBackend.Init(ctx, ID) +} + +func (l *LambdaBackend) Job(ctx context.Context) *models.Job { + return l.storageBackend.Job(ctx) +} + +func (l *LambdaBackend) RefetchJob(ctx context.Context) (*models.Job, error) { + 
return l.storageBackend.RefetchJob(ctx) +} + +func (l *LambdaBackend) UpdateState(ctx context.Context) error { + return l.storageBackend.UpdateState(ctx) +} + +func (l *LambdaBackend) CheckStop(ctx context.Context) (bool, error) { + return l.storageBackend.CheckStop(ctx) +} + +func (l *LambdaBackend) IsInterrupted(ctx context.Context) (bool, error) { + return false, nil +} + +func (l *LambdaBackend) Shutdown(ctx context.Context) error { + return nil +} + +func (l *LambdaBackend) GetArtifact(ctx context.Context, runName, localPath, remotePath string, mount bool) artifacts.Artifacter { + return l.storageBackend.GetArtifact(ctx, runName, localPath, remotePath, mount) +} + +func (l *LambdaBackend) GetCache(ctx context.Context, runName, localPath, remotePath string) artifacts.Artifacter { + return l.storageBackend.GetCache(ctx, runName, localPath, remotePath) +} + +func (l *LambdaBackend) Requirements(ctx context.Context) models.Requirements { + return l.storageBackend.Requirements(ctx) +} + +func (l *LambdaBackend) MasterJob(ctx context.Context) *models.Job { + return l.storageBackend.MasterJob(ctx) +} + +func (l *LambdaBackend) CreateLogger(ctx context.Context, logGroup, logName string) io.Writer { + return l.storageBackend.CreateLogger(ctx, logGroup, logName) +} + +func (l *LambdaBackend) ListSubDir(ctx context.Context, dir string) ([]string, error) { + return l.storageBackend.ListSubDir(ctx, dir) +} + +func (l *LambdaBackend) GetJobByPath(ctx context.Context, path string) (*models.Job, error) { + return l.storageBackend.GetJobByPath(ctx, path) +} + +func (l *LambdaBackend) Bucket(ctx context.Context) string { + return l.storageBackend.Bucket(ctx) +} + +func (l *LambdaBackend) Secrets(ctx context.Context) (map[string]string, error) { + return l.storageBackend.Secrets(ctx) +} + +func (l *LambdaBackend) GitCredentials(ctx context.Context) *models.GitCredentials { + return l.storageBackend.GitCredentials(ctx) +} + +func (l *LambdaBackend) GetRepoDiff(ctx 
context.Context, path string) (string, error) { + return l.storageBackend.GetRepoDiff(ctx, path) +} + +func (l *LambdaBackend) GetRepoArchive(ctx context.Context, path, dir string) error { + return l.storageBackend.GetRepoArchive(ctx, path, dir) +} + +func (l *LambdaBackend) GetBuildDiff(ctx context.Context, key, dst string) error { + return l.storageBackend.GetBuildDiff(ctx, key, dst) +} + +func (l *LambdaBackend) PutBuildDiff(ctx context.Context, src, key string) error { + return l.storageBackend.PutBuildDiff(ctx, src, key) +} + +func (l *LambdaBackend) GetTMPDir(ctx context.Context) string { + return l.storageBackend.GetTMPDir(ctx) +} + +func (l *LambdaBackend) GetDockerBindings(ctx context.Context) []mount.Mount { + return l.storageBackend.GetDockerBindings(ctx) +} diff --git a/runner/main.go b/runner/main.go index 24eb626aa..b8dd7c0d3 100644 --- a/runner/main.go +++ b/runner/main.go @@ -21,10 +21,11 @@ import ( "github.com/dstackai/dstack/runner/consts" "github.com/dstackai/dstack/runner/internal/backend" + _ "github.com/dstackai/dstack/runner/internal/backend/aws" _ "github.com/dstackai/dstack/runner/internal/backend/azure" _ "github.com/dstackai/dstack/runner/internal/backend/gcp" + _ "github.com/dstackai/dstack/runner/internal/backend/lambda" _ "github.com/dstackai/dstack/runner/internal/backend/local" - _ "github.com/dstackai/dstack/runner/internal/backend/s3" "github.com/dstackai/dstack/runner/internal/container" "github.com/dstackai/dstack/runner/internal/executor" "github.com/dstackai/dstack/runner/internal/log" From f1dfa6509565daf6cb93265c76af687153b2f10a Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 23 Jun 2023 15:42:03 +0500 Subject: [PATCH 02/16] Prototype LambdaLabs server backend --- cli/dstack/_internal/backend/base/__init__.py | 4 - cli/dstack/_internal/backend/base/compute.py | 3 +- .../_internal/backend/lambdalabs/__init__.py | 274 ++++++++++++++++++ .../backend/lambdalabs/api_client.py | 55 ++++ 
.../_internal/backend/lambdalabs/compute.py | 148 ++++++++++ .../_internal/backend/lambdalabs/config.py | 21 ++ cli/dstack/_internal/hub/models/__init__.py | 26 +- .../hub/services/backends/__init__.py | 2 + .../services/backends/lambdalabs/__init__.py | 0 .../backends/lambdalabs/configurator.py | 52 ++++ 10 files changed, 575 insertions(+), 10 deletions(-) create mode 100644 cli/dstack/_internal/backend/lambdalabs/__init__.py create mode 100644 cli/dstack/_internal/backend/lambdalabs/api_client.py create mode 100644 cli/dstack/_internal/backend/lambdalabs/compute.py create mode 100644 cli/dstack/_internal/backend/lambdalabs/config.py create mode 100644 cli/dstack/_internal/hub/services/backends/lambdalabs/__init__.py create mode 100644 cli/dstack/_internal/hub/services/backends/lambdalabs/configurator.py diff --git a/cli/dstack/_internal/backend/base/__init__.py b/cli/dstack/_internal/backend/base/__init__.py index bb27bdd19..3ab883b57 100644 --- a/cli/dstack/_internal/backend/base/__init__.py +++ b/cli/dstack/_internal/backend/base/__init__.py @@ -21,10 +21,6 @@ class Backend(ABC): NAME = None - backend_config: BackendConfig - _storage: Storage - _compute: Compute - _secrets_manager: SecretsManager def __init__( self, diff --git a/cli/dstack/_internal/backend/base/compute.py b/cli/dstack/_internal/backend/base/compute.py index 5d79ed1af..171954edc 100644 --- a/cli/dstack/_internal/backend/base/compute.py +++ b/cli/dstack/_internal/backend/base/compute.py @@ -2,6 +2,7 @@ from functools import cmp_to_key from typing import List, Optional +from dstack._internal.core.error import DstackError from dstack._internal.core.instance import InstanceType from dstack._internal.core.job import Job, Requirements from dstack._internal.core.request import RequestHead @@ -10,7 +11,7 @@ WS_PORT = 10999 -class ComputeError(Exception): +class ComputeError(DstackError): pass diff --git a/cli/dstack/_internal/backend/lambdalabs/__init__.py 
b/cli/dstack/_internal/backend/lambdalabs/__init__.py new file mode 100644 index 000000000..8c9ec5313 --- /dev/null +++ b/cli/dstack/_internal/backend/lambdalabs/__init__.py @@ -0,0 +1,274 @@ +from datetime import datetime +from typing import Generator, List, Optional + +import boto3 +from botocore.client import BaseClient + +from dstack._internal.backend.aws import logs +from dstack._internal.backend.aws.config import AWSConfig +from dstack._internal.backend.aws.secrets import AWSSecretsManager +from dstack._internal.backend.aws.storage import AWSStorage +from dstack._internal.backend.base import Backend +from dstack._internal.backend.base import artifacts as base_artifacts +from dstack._internal.backend.base import cache as base_cache +from dstack._internal.backend.base import jobs as base_jobs +from dstack._internal.backend.base import repos as base_repos +from dstack._internal.backend.base import runs as base_runs +from dstack._internal.backend.base import secrets as base_secrets +from dstack._internal.backend.base import tags as base_tags +from dstack._internal.backend.lambdalabs.compute import LambdaCompute +from dstack._internal.backend.lambdalabs.config import LambdaConfig +from dstack._internal.core.artifact import Artifact +from dstack._internal.core.instance import InstanceType +from dstack._internal.core.job import Job, JobHead, JobStatus +from dstack._internal.core.log_event import LogEvent +from dstack._internal.core.repo import RemoteRepoCredentials, RepoHead, RepoSpec +from dstack._internal.core.repo.base import Repo +from dstack._internal.core.run import RunHead +from dstack._internal.core.secret import Secret +from dstack._internal.core.tag import TagHead +from dstack._internal.utils.common import PathLike + + +class LambdaBackend(Backend): + NAME = "lambda" + + def __init__( + self, + backend_config: LambdaConfig, + ): + self.backend_config = backend_config + self._compute = LambdaCompute(api_key=self.backend_config.api_key) + self._session = 
boto3.session.Session( + region_name=self.backend_config.storage_config.region, + aws_access_key_id=self.backend_config.storage_config.credentials.access_key, + aws_secret_access_key=self.backend_config.storage_config.credentials.secret_key, + ) + self._storage = AWSStorage( + s3_client=self._s3_client(), bucket_name=self.backend_config.storage_config.bucket + ) + self._secrets_manager = AWSSecretsManager( + secretsmanager_client=self._secretsmanager_client(), + iam_client=self._iam_client(), + sts_client=self._sts_client(), + bucket_name=self.backend_config.storage_config.bucket, + ) + + @classmethod + def load(cls) -> Optional["LambdaBackend"]: + config = AWSConfig.load() + if config is None: + return None + return cls( + backend_config=config, + ) + + def _s3_client(self) -> BaseClient: + return self._get_client("s3") + + def _ec2_client(self) -> BaseClient: + return self._get_client("ec2") + + def _iam_client(self) -> BaseClient: + return self._get_client("iam") + + def _logs_client(self) -> BaseClient: + return self._get_client("logs") + + def _secretsmanager_client(self) -> BaseClient: + return self._get_client("secretsmanager") + + def _sts_client(self) -> BaseClient: + return self._get_client("sts") + + def _get_client(self, client_name: str) -> BaseClient: + return self._session.client(client_name) + + def predict_instance_type(self, job: Job) -> Optional[InstanceType]: + return base_jobs.predict_job_instance(self._compute, job) + + def create_run(self, repo_id: str) -> str: + logs.create_log_groups_if_not_exist( + self._logs_client(), self.backend_config.storage_config.bucket, repo_id + ) + return base_runs.create_run(self._storage) + + def create_job(self, job: Job): + base_jobs.create_job(self._storage, job) + + def get_job(self, repo_id: str, job_id: str) -> Optional[Job]: + return base_jobs.get_job(self._storage, repo_id, job_id) + + def list_jobs(self, repo_id: str, run_name: str) -> List[Job]: + return base_jobs.list_jobs(self._storage, repo_id, 
run_name) + + def run_job(self, job: Job, failed_to_start_job_new_status: JobStatus): + base_jobs.run_job(self._storage, self._compute, job, failed_to_start_job_new_status) + + def stop_job(self, repo_id: str, abort: bool, job_id: str): + base_jobs.stop_job(self._storage, self._compute, repo_id, job_id, abort) + + def list_job_heads(self, repo_id: str, run_name: Optional[str] = None) -> List[JobHead]: + return base_jobs.list_job_heads(self._storage, repo_id, run_name) + + def delete_job_head(self, repo_id: str, job_id: str): + base_jobs.delete_job_head(self._storage, repo_id, job_id) + + def list_run_heads( + self, + repo_id: str, + run_name: Optional[str] = None, + include_request_heads: bool = True, + interrupted_job_new_status: JobStatus = JobStatus.FAILED, + ) -> List[RunHead]: + job_heads = self.list_job_heads(repo_id=repo_id, run_name=run_name) + return base_runs.get_run_heads( + self._storage, + self._compute, + job_heads, + include_request_heads, + interrupted_job_new_status, + ) + + def poll_logs( + self, + repo_id: str, + run_name: str, + start_time: datetime, + end_time: Optional[datetime] = None, + descending: bool = False, + diagnose: bool = False, + ) -> Generator[LogEvent, None, None]: + return logs.poll_logs( + self._storage, + self._logs_client(), + self.backend_config.storage_config.bucket, + repo_id, + run_name, + start_time, + end_time, + descending, + diagnose, + ) + + def list_run_artifact_files( + self, repo_id: str, run_name: str, prefix: str, recursive: bool = False + ) -> List[Artifact]: + return base_artifacts.list_run_artifact_files( + self._storage, repo_id, run_name, prefix, recursive + ) + + def download_run_artifact_files( + self, + repo_id: str, + run_name: str, + output_dir: Optional[PathLike], + files_path: Optional[PathLike] = None, + ): + artifacts = self.list_run_artifact_files( + repo_id, run_name=run_name, prefix="", recursive=True + ) + base_artifacts.download_run_artifact_files( + storage=self._storage, + repo_id=repo_id, + 
artifacts=artifacts, + output_dir=output_dir, + files_path=files_path, + ) + + def upload_job_artifact_files( + self, + repo_id: str, + job_id: str, + artifact_name: str, + artifact_path: PathLike, + local_path: PathLike, + ): + base_artifacts.upload_job_artifact_files( + storage=self._storage, + repo_id=repo_id, + job_id=job_id, + artifact_name=artifact_name, + artifact_path=artifact_path, + local_path=local_path, + ) + + def list_tag_heads(self, repo_id: str) -> List[TagHead]: + return base_tags.list_tag_heads(self._storage, repo_id) + + def get_tag_head(self, repo_id: str, tag_name: str) -> Optional[TagHead]: + return base_tags.get_tag_head(self._storage, repo_id, tag_name) + + def add_tag_from_run( + self, repo_id: str, tag_name: str, run_name: str, run_jobs: Optional[List[Job]] + ): + base_tags.create_tag_from_run( + self._storage, + repo_id, + tag_name, + run_name, + run_jobs, + ) + + def add_tag_from_local_dirs( + self, + repo: Repo, + hub_user_name: str, + tag_name: str, + local_dirs: List[str], + artifact_paths: List[str], + ): + base_tags.create_tag_from_local_dirs( + storage=self._storage, + repo=repo, + hub_user_name=hub_user_name, + tag_name=tag_name, + local_dirs=local_dirs, + artifact_paths=artifact_paths, + ) + + def delete_tag_head(self, repo_id: str, tag_head: TagHead): + base_tags.delete_tag(self._storage, repo_id, tag_head) + + def list_repo_heads(self) -> List[RepoHead]: + return base_repos.list_repo_heads(self._storage) + + def update_repo_last_run_at(self, repo_spec: RepoSpec, last_run_at: int): + base_repos.update_repo_last_run_at( + self._storage, + repo_spec, + last_run_at, + ) + + def get_repo_credentials(self, repo_id: str) -> Optional[RemoteRepoCredentials]: + return base_repos.get_repo_credentials(self._secrets_manager, repo_id) + + def save_repo_credentials(self, repo_id: str, repo_credentials: RemoteRepoCredentials): + base_repos.save_repo_credentials(self._secrets_manager, repo_id, repo_credentials) + + def delete_repo(self, 
repo_id: str): + base_repos.delete_repo(self._storage, repo_id) + + def list_secret_names(self, repo_id: str) -> List[str]: + return base_secrets.list_secret_names(self._storage, repo_id) + + def get_secret(self, repo_id: str, secret_name: str) -> Optional[Secret]: + return base_secrets.get_secret(self._secrets_manager, repo_id, secret_name) + + def add_secret(self, repo_id: str, secret: Secret): + base_secrets.add_secret(self._storage, self._secrets_manager, repo_id, secret) + + def update_secret(self, repo_id: str, secret: Secret): + base_secrets.update_secret(self._storage, self._secrets_manager, repo_id, secret) + + def delete_secret(self, repo_id: str, secret_name: str): + base_secrets.delete_secret(self._storage, self._secrets_manager, repo_id, secret_name) + + def get_signed_download_url(self, object_key: str) -> str: + return self._storage.get_signed_download_url(object_key) + + def get_signed_upload_url(self, object_key: str) -> str: + return self._storage.get_signed_upload_url(object_key) + + def delete_workflow_cache(self, repo_id: str, hub_user_name: str, workflow_name: str): + base_cache.delete_workflow_cache(self._storage, repo_id, hub_user_name, workflow_name) diff --git a/cli/dstack/_internal/backend/lambdalabs/api_client.py b/cli/dstack/_internal/backend/lambdalabs/api_client.py new file mode 100644 index 000000000..e93064de0 --- /dev/null +++ b/cli/dstack/_internal/backend/lambdalabs/api_client.py @@ -0,0 +1,55 @@ +from typing import Any, List, Optional + +import requests + +API_URL = "https://cloud.lambdalabs.com/api/v1" + + +class LambdaAPIClient: + def __init__(self, api_key: str): + self.api_key = api_key + + def list_instance_types(self): + resp = self._make_request("GET", "/instance-types") + if resp.ok: + return resp.json()["data"] + resp.raise_for_status() + + def list_instances(self): + resp = self._make_request("GET", "/instances") + if resp.ok: + return resp.json()["data"] + resp.raise_for_status() + + def launch_instances( + self, 
region_name: str, + instance_type_name: str, + ssh_key_names: List[str], + file_system_names: List[str], + quantity: int, + name: Optional[str], + ) -> List[str]: + data = { + "region_name": region_name, + "instance_type_name": instance_type_name, + "ssh_key_names": ssh_key_names, + "file_system_names": file_system_names, + "quantity": quantity, + "name": name, + } + resp = self._make_request("POST", "/instance-operations/launch", data) + if resp.ok: + return resp.json()["data"]["instance_ids"] + resp.raise_for_status() + + def _make_request(self, method: str, path: str, data: Any = None): + return requests.request( + method=method, + url=API_URL + path, + json=data, + headers={"Authorization": f"Bearer {self.api_key}"}, + ) + + def _url(self, path: str) -> str: + return API_URL + path diff --git a/cli/dstack/_internal/backend/lambdalabs/compute.py b/cli/dstack/_internal/backend/lambdalabs/compute.py new file mode 100644 index 000000000..31cee54a8 --- /dev/null +++ b/cli/dstack/_internal/backend/lambdalabs/compute.py @@ -0,0 +1,148 @@ +import time +from typing import Dict, List, Optional + +from dstack._internal.backend.base.compute import choose_instance_type +from dstack._internal.backend.lambdalabs.api_client import LambdaAPIClient +from dstack._internal.core.instance import InstanceType +from dstack._internal.core.job import Job +from dstack._internal.core.request import RequestHead, RequestStatus +from dstack._internal.core.runners import Gpu, Resources + + +class LambdaCompute: + def __init__(self, api_key: str): + self.api_client = LambdaAPIClient(api_key) + + def get_request_head(self, job: Job, request_id: Optional[str]) -> RequestHead: + instance_info = _get_instance_info(self.api_client, request_id) + if instance_info is None or instance_info["status"] == "terminated": + return RequestHead(job_id=job.job_id, status=RequestStatus.TERMINATED) + return RequestHead( + job_id=job.job_id, + status=RequestStatus.RUNNING, + ) + + def get_instance_type(self, job: 
Job) -> Optional[InstanceType]: + instance_types = _list_instance_types(self.api_client) + return choose_instance_type( + instance_types=instance_types, + requirements=job.requirements, + ) + + def run_instance(self, job: Job, instance_type: InstanceType) -> str: + return _run_instance( + api_client=self.api_client, + region_name="us-west-1", + instance_type_name=instance_type.instance_name, + ssh_key_name="dstack_victor", + instance_name=_get_instance_name(job), + ) + + def terminate_instance(self, request_id: str): + pass + + def cancel_spot_request(self, request_id: str): + pass + + +def _list_instance_types(api_client: LambdaAPIClient) -> List[InstanceType]: + instance_types_data = api_client.list_instance_types() + instance_types = [] + for instance_type_data in instance_types_data.values(): + instance_type = _instance_type_data_to_instance_type(instance_type_data) + if instance_type is not None: + instance_types.append(instance_type) + return instance_types + + +def _get_instance_info(api_client: LambdaAPIClient, instance_id: str) -> Optional[Dict]: + instances = api_client.list_instances() + instance_id_to_instance_map = {i["id"]: i for i in instances} + instance = instance_id_to_instance_map.get(instance_id) + if instance is None: + return None + return instance + + +def _instance_type_data_to_instance_type(instance_type_data: Dict) -> Optional[InstanceType]: + instance_type = instance_type_data["instance_type"] + regions = instance_type_data["regions_with_capacity_available"] + if len(regions) == 0: + return None + instance_type_specs = instance_type["specs"] + gpus = _get_instance_type_gpus(instance_type["name"]) + if gpus is None: + return None + return InstanceType( + instance_name=instance_type["name"], + resources=Resources( + cpus=instance_type_specs["vcpus"], + memory_mib=instance_type_specs["memory_gib"] * 1024, + gpus=gpus, + spot=False, + local=False, + ), + ) + + +_INSTANCE_TYPE_TO_GPU_DATA_MAP = { + "gpu_1x_a10": { + "name": "A10", + "count": 
1, + "memory_mib": 24 * 1024, + }, + "gpu_1x_rtx6000": { + "name": "RTX6000", + "count": 1, + "memory_mib": 24 * 1024, + }, +} + + +def _get_instance_type_gpus(instance_type_name: str) -> Optional[List[Gpu]]: + gpu_data = _INSTANCE_TYPE_TO_GPU_DATA_MAP.get(instance_type_name) + if gpu_data is None: + return None + return [ + Gpu(name=gpu_data["name"], memory_mib=gpu_data["memory_mib"]) + for _ in range(gpu_data["count"]) + ] + + +def _get_instance_name(job: Job) -> str: + return f"dstack-{job.run_name}" + + +def _run_instance( + api_client: LambdaAPIClient, + region_name: str, + instance_type_name: str, + ssh_key_name: str, + instance_name: str, +) -> str: + instances_ids = api_client.launch_instances( + region_name=region_name, + instance_type_name=instance_type_name, + ssh_key_names=[ssh_key_name], + name=instance_name, + quantity=1, + file_system_names=[], + ) + instance_id = instances_ids[0] + instance_info = _wait_for_instance(api_client, instance_id) + return instance_id + + +def _wait_for_instance( + api_client: LambdaAPIClient, + instance_id: str, +) -> Dict: + while True: + instance_info = _get_instance_info(api_client, instance_id) + if instance_info is None or instance_info["status"] != "booting": + return + time.sleep(10) + + +def _launch_runner(hostname: str): + pass diff --git a/cli/dstack/_internal/backend/lambdalabs/config.py b/cli/dstack/_internal/backend/lambdalabs/config.py new file mode 100644 index 000000000..42cb82310 --- /dev/null +++ b/cli/dstack/_internal/backend/lambdalabs/config.py @@ -0,0 +1,21 @@ +from typing import Dict, Optional + +from pydantic import BaseModel +from typing_extensions import Literal + + +class AWSStorageConfigCredentials(BaseModel): + access_key: str + secret_key: str + + +class AWSStorageConfig(BaseModel): + bucket: str + region: str + credentials: AWSStorageConfigCredentials + + +class LambdaConfig(BaseModel): + type: Literal["lambda"] = "lambda" + api_key: str + storage_config: AWSStorageConfig diff --git 
a/cli/dstack/_internal/hub/models/__init__.py b/cli/dstack/_internal/hub/models/__init__.py index 9cb5adf60..d2b5ddc3e 100644 --- a/cli/dstack/_internal/hub/models/__init__.py +++ b/cli/dstack/_internal/hub/models/__init__.py @@ -30,7 +30,9 @@ class Member(BaseModel): project_role: ProjectRole -BackendType = Union[Literal["local"], Literal["aws"], Literal["gcp"], Literal["azure"]] +BackendType = Union[ + Literal["local"], Literal["aws"], Literal["gcp"], Literal["azure"], Literal["lambda"] +] class LocalProjectConfig(BaseModel): @@ -162,20 +164,30 @@ class AzureProjectConfigWithCreds(AzureProjectConfig): credentials: AzureProjectCreds +class LambdaProjectConfig(BaseModel): + type: Literal["lambda"] = "lambda" + + +class LambdaProjectConfigWithCreds(LambdaProjectConfig): + api_key: str + + AnyProjectConfig = Union[ - LocalProjectConfig, AWSProjectConfig, GCPProjectConfig, AzureProjectConfig + LocalProjectConfig, AWSProjectConfig, GCPProjectConfig, AzureProjectConfig, LambdaProjectConfig ] AnyProjectConfigWithCredsPartial = Union[ LocalProjectConfig, AWSProjectConfigWithCredsPartial, GCPProjectConfigWithCredsPartial, AzureProjectConfigWithCredsPartial, + LambdaProjectConfig, ] AnyProjectConfigWithCreds = Union[ LocalProjectConfig, AWSProjectConfigWithCreds, GCPProjectConfigWithCreds, AzureProjectConfigWithCreds, + LambdaProjectConfigWithCreds, ] @@ -366,10 +378,14 @@ class AzureProjectValues(BaseModel): storage_account: Optional[ProjectElement] +class LambdaProjectValues(BaseModel): + type: Literal["lambda"] = "lambda" + + class ProjectValues(BaseModel): - __root__: Union[None, AWSProjectValues, GCPProjectValues, AzureProjectValues] = Field( - ..., discriminator="type" - ) + __root__: Union[ + None, AWSProjectValues, GCPProjectValues, AzureProjectValues, LambdaProjectValues + ] = Field(..., discriminator="type") class UserPatch(BaseModel): diff --git a/cli/dstack/_internal/hub/services/backends/__init__.py b/cli/dstack/_internal/hub/services/backends/__init__.py index 
2e9980fad..1d68d107b 100644 --- a/cli/dstack/_internal/hub/services/backends/__init__.py +++ b/cli/dstack/_internal/hub/services/backends/__init__.py @@ -3,6 +3,7 @@ from dstack._internal.hub.models import BackendType from dstack._internal.hub.services.backends.base import Configurator +from dstack._internal.hub.services.backends.lambdalabs.configurator import LambdaConfigurator from dstack._internal.hub.services.backends.local.configurator import LocalConfigurator configurators_classes = [] @@ -11,6 +12,7 @@ from dstack._internal.hub.services.backends.aws.configurator import AWSConfigurator configurators_classes.append(AWSConfigurator) + configurators_classes.append(LambdaConfigurator) except ImportError: pass diff --git a/cli/dstack/_internal/hub/services/backends/lambdalabs/__init__.py b/cli/dstack/_internal/hub/services/backends/lambdalabs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cli/dstack/_internal/hub/services/backends/lambdalabs/configurator.py b/cli/dstack/_internal/hub/services/backends/lambdalabs/configurator.py new file mode 100644 index 000000000..9bf7087a7 --- /dev/null +++ b/cli/dstack/_internal/hub/services/backends/lambdalabs/configurator.py @@ -0,0 +1,52 @@ +import os +from typing import Dict, Tuple, Union + +from dstack._internal.backend.lambdalabs import LambdaBackend +from dstack._internal.backend.lambdalabs.config import ( + AWSStorageConfig, + AWSStorageConfigCredentials, + LambdaConfig, +) +from dstack._internal.hub.db.models import Project +from dstack._internal.hub.models import ( + LambdaProjectConfig, + LambdaProjectConfigWithCreds, + ProjectValues, +) + + +class LambdaConfigurator: + NAME = "lambda" + + def get_backend_class(self) -> type: + return LambdaBackend + + def configure_project(self, config_data: Dict) -> ProjectValues: + return None + + def create_config_auth_data_from_project_config( + self, project_config: LambdaProjectConfigWithCreds + ) -> Tuple[Dict, Dict]: + config_data = 
LambdaProjectConfig.parse_obj(project_config).dict() + auth_data = {"api_key": project_config.api_key} + return config_data, auth_data + + def get_project_config_from_project( + self, project: Project, include_creds: bool + ) -> Union[LambdaProjectConfig, LambdaProjectConfigWithCreds]: + return LambdaProjectConfig() + + def get_backend_config_from_hub_config_data( + self, project_name: str, config_data: Dict, auth_data: Dict + ) -> LambdaConfig: + return LambdaConfig( + api_key=os.environ["LAMBDA_API_KEY"], + storage_config=AWSStorageConfig( + region="eu-west-1", + bucket="dstack-lambda-eu-west-1", + credentials=AWSStorageConfigCredentials( + access_key=os.environ["AWS_ACCESS_KEY_ID"], + secret_key=os.environ["AWS_SECRET_ACCESS_KEY"], + ), + ), + ) From c7310567969e699f43f41905f3fa3ab73da7efaa Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 26 Jun 2023 14:48:37 +0500 Subject: [PATCH 03/16] Set up lambda instances --- .../_internal/backend/lambdalabs/__init__.py | 2 +- .../backend/lambdalabs/api_client.py | 25 ++- .../_internal/backend/lambdalabs/compute.py | 142 ++++++++++++++++-- cli/dstack/_internal/hub/main.py | 2 + cli/dstack/_internal/hub/utils/ssh.py | 20 +++ cli/dstack/_internal/scripts/setup_lambda.sh | 21 +++ cli/dstack/_internal/utils/crypto.py | 36 +++++ .../packer/provisioners/get-dstack-runner.sh | 17 ++- setup.py | 1 + 9 files changed, 243 insertions(+), 23 deletions(-) create mode 100644 cli/dstack/_internal/hub/utils/ssh.py create mode 100644 cli/dstack/_internal/scripts/setup_lambda.sh create mode 100644 cli/dstack/_internal/utils/crypto.py diff --git a/cli/dstack/_internal/backend/lambdalabs/__init__.py b/cli/dstack/_internal/backend/lambdalabs/__init__.py index 8c9ec5313..ddeffb0d3 100644 --- a/cli/dstack/_internal/backend/lambdalabs/__init__.py +++ b/cli/dstack/_internal/backend/lambdalabs/__init__.py @@ -38,7 +38,7 @@ def __init__( backend_config: LambdaConfig, ): self.backend_config = backend_config - self._compute = 
LambdaCompute(api_key=self.backend_config.api_key) + self._compute = LambdaCompute(lambda_config=self.backend_config) self._session = boto3.session.Session( region_name=self.backend_config.storage_config.region, aws_access_key_id=self.backend_config.storage_config.credentials.access_key, diff --git a/cli/dstack/_internal/backend/lambdalabs/api_client.py b/cli/dstack/_internal/backend/lambdalabs/api_client.py index e93064de0..fc28cdedd 100644 --- a/cli/dstack/_internal/backend/lambdalabs/api_client.py +++ b/cli/dstack/_internal/backend/lambdalabs/api_client.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional import requests @@ -43,6 +43,29 @@ def launch_instances( return resp.json()["data"]["instance_ids"] resp.raise_for_status() + def terminate_instances(self, instance_ids: List[str]) -> List[str]: + data = {"instance_ids": instance_ids} + resp = self._make_request("POST", "/instance-operations/terminate", data) + if resp.ok: + return resp.json()["data"] + resp.raise_for_status() + + def list_ssh_keys(self) -> List[Dict]: + resp = self._make_request("GET", "/ssh-keys") + if resp.ok: + return resp.json()["data"] + resp.raise_for_status() + + def add_ssh_key(self, name: str, public_key: str) -> List[Dict]: + data = { + "name": name, + "public_key": public_key, + } + resp = self._make_request("POST", "/ssh-keys", data) + if resp.ok: + return resp.json()["data"] + resp.raise_for_status() + def _make_request(self, method: str, path: str, data: Any = None): return requests.request( method=method, diff --git a/cli/dstack/_internal/backend/lambdalabs/compute.py b/cli/dstack/_internal/backend/lambdalabs/compute.py index 31cee54a8..5b7831da2 100644 --- a/cli/dstack/_internal/backend/lambdalabs/compute.py +++ b/cli/dstack/_internal/backend/lambdalabs/compute.py @@ -1,17 +1,32 @@ +import hashlib +import os +import subprocess +import tempfile import time -from typing import Dict, List, Optional +from typing import Dict, List, 
Optional, Tuple -from dstack._internal.backend.base.compute import choose_instance_type +import pkg_resources +import yaml + +from dstack._internal.backend.base.compute import WS_PORT, choose_instance_type +from dstack._internal.backend.base.config import BACKEND_CONFIG_FILENAME, RUNNER_CONFIG_FILENAME +from dstack._internal.backend.base.runners import serialize_runner_yaml from dstack._internal.backend.lambdalabs.api_client import LambdaAPIClient +from dstack._internal.backend.lambdalabs.config import LambdaConfig from dstack._internal.core.instance import InstanceType from dstack._internal.core.job import Job from dstack._internal.core.request import RequestHead, RequestStatus from dstack._internal.core.runners import Gpu, Resources +from dstack._internal.hub.utils.ssh import HUB_PRIVATE_KEY_PATH, get_hub_ssh_public_key + +_WAIT_FOR_INSTANCE_ATTEMPTS = 120 +_WAIT_FOR_INSTANCE_INTERVAL = 10 class LambdaCompute: - def __init__(self, api_key: str): - self.api_client = LambdaAPIClient(api_key) + def __init__(self, lambda_config: LambdaConfig): + self.lambda_config = lambda_config + self.api_client = LambdaAPIClient(lambda_config.api_key) def get_request_head(self, job: Job, request_id: Optional[str]) -> RequestHead: instance_info = _get_instance_info(self.api_client, request_id) @@ -34,12 +49,14 @@ def run_instance(self, job: Job, instance_type: InstanceType) -> str: api_client=self.api_client, region_name="us-west-1", instance_type_name=instance_type.instance_name, - ssh_key_name="dstack_victor", + user_ssh_key=job.ssh_key_pub, + hub_ssh_key=get_hub_ssh_public_key(), instance_name=_get_instance_name(job), + launch_script=_get_launch_script(self.lambda_config, job, instance_type), ) def terminate_instance(self, request_id: str): - pass + self.api_client.terminate_instances(instance_ids=[request_id]) def cancel_spot_request(self, request_id: str): pass @@ -113,36 +130,135 @@ def _get_instance_name(job: Job) -> str: return f"dstack-{job.run_name}" +def 
_get_launch_script(lambda_config: LambdaConfig, job: Job, instance_type: InstanceType) -> str: + config_content = yaml.dump(lambda_config.dict()).replace("\n", "\\n") + runner_content = serialize_runner_yaml(job.runner_id, instance_type.resources, 3000, 4000) + return f"""#!/bin/sh +mkdir -p /root/.dstack/ +echo '{config_content}' > /root/.dstack/{BACKEND_CONFIG_FILENAME} +echo '{runner_content}' > /root/.dstack/{RUNNER_CONFIG_FILENAME} +echo 'hostname: HOSTNAME_PLACEHOLDER' >> /root/.dstack/{RUNNER_CONFIG_FILENAME} +HOME=/root nohup dstack-runner --log-level 6 start --http-port {WS_PORT} +""" + + def _run_instance( api_client: LambdaAPIClient, region_name: str, instance_type_name: str, - ssh_key_name: str, + user_ssh_key: str, + hub_ssh_key: str, instance_name: str, + launch_script: str, ) -> str: + _, hub_key_name = _add_ssh_keys(api_client, user_ssh_key, hub_ssh_key) instances_ids = api_client.launch_instances( region_name=region_name, instance_type_name=instance_type_name, - ssh_key_names=[ssh_key_name], + ssh_key_names=[hub_key_name], name=instance_name, quantity=1, file_system_names=[], ) instance_id = instances_ids[0] instance_info = _wait_for_instance(api_client, instance_id) + hostname = instance_info["ip"] + _setup_instance(hostname=hostname, user_ssh_key=user_ssh_key) + _launch_runner(hostname=hostname, launch_script=launch_script) return instance_id +def _add_ssh_keys( + api_client: LambdaAPIClient, user_ssh_key: str, hub_ssh_key: str +) -> Tuple[str, str]: + ssh_keys = api_client.list_ssh_keys() + ssh_key_names = [k["name"] for k in ssh_keys] + user_key_name = _add_ssh_key(api_client, ssh_key_names, user_ssh_key) + hub_key_name = _add_ssh_key(api_client, ssh_key_names, hub_ssh_key) + return user_key_name, hub_key_name + + +def _add_ssh_key(api_client: LambdaAPIClient, ssh_key_names: str, public_key: str) -> str: + key_name = _get_ssh_key_name(public_key) + if key_name in ssh_key_names: + return key_name + api_client.add_ssh_key(name=key_name, 
public_key=public_key) + return key_name + + +def _get_ssh_key_name(public_key: str) -> str: + return hashlib.sha1(public_key.encode()).hexdigest()[-16:] + + def _wait_for_instance( api_client: LambdaAPIClient, instance_id: str, ) -> Dict: - while True: + for _ in range(_WAIT_FOR_INSTANCE_ATTEMPTS): instance_info = _get_instance_info(api_client, instance_id) if instance_info is None or instance_info["status"] != "booting": - return - time.sleep(10) + return instance_info + time.sleep(_WAIT_FOR_INSTANCE_INTERVAL) -def _launch_runner(hostname: str): - pass +def _setup_instance(hostname: str, user_ssh_key: str): + setup_script_path = pkg_resources.resource_filename( + "dstack._internal", "scripts/setup_lambda.sh" + ) + _run_ssh_command(hostname=hostname, commands="mkdir /home/ubuntu/.dstack") + _run_scp_command(hostname=hostname, source=setup_script_path, target="/home/ubuntu/.dstack") + # Lambda API allows specifying only one ssh key, + # so we have to update authorized_keys manually to add the user key + setup_commands = ( + "chmod +x .dstack/setup_lambda.sh && " + ".dstack/setup_lambda.sh && " + f"echo '{user_ssh_key}' >> /home/ubuntu/.ssh/authorized_keys" + ) + _run_ssh_command(hostname=hostname, commands=setup_commands) + + +def _launch_runner(hostname: str, launch_script: str): + launch_script = launch_script.replace("HOSTNAME_PLACEHOLDER", hostname) + with tempfile.NamedTemporaryFile("w+") as f: + f.write(launch_script) + f.flush() + filepath = os.path.join(tempfile.gettempdir(), f.name) + _run_scp_command( + hostname=hostname, source=filepath, target="/home/ubuntu/.dstack/launch_runner.sh" + ) + _run_ssh_command( + hostname=hostname, + commands="chmod +x .dstack/launch_runner.sh && sudo .dstack/launch_runner.sh", + ) + + +def _run_ssh_command(hostname: str, commands: str): + subprocess.run( + [ + "ssh", + "-o", + "StrictHostKeyChecking=no", + "-i", + HUB_PRIVATE_KEY_PATH, + f"ubuntu@{hostname}", + commands, + ], + stdout=subprocess.DEVNULL, + 
stderr=subprocess.DEVNULL, + ) + + +def _run_scp_command(hostname: str, source: str, target: str): + subprocess.run( + [ + "scp", + "-o", + "StrictHostKeyChecking=no", + "-i", + HUB_PRIVATE_KEY_PATH, + source, + f"ubuntu@{hostname}:{target}", + ], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) diff --git a/cli/dstack/_internal/hub/main.py b/cli/dstack/_internal/hub/main.py index fc9e1a02a..400072473 100644 --- a/cli/dstack/_internal/hub/main.py +++ b/cli/dstack/_internal/hub/main.py @@ -31,6 +31,7 @@ ) from dstack._internal.hub.services.backends import local_backend_available from dstack._internal.hub.utils import logging +from dstack._internal.hub.utils.ssh import generate_hub_ssh_key_pair logging.configure_root_logger() logger = logging.get_logger(__name__) @@ -65,6 +66,7 @@ async def startup_event(): url_with_token = f"{url}?token={admin_user.token}" create_default_project_config(url, admin_user.token) start_background_tasks() + generate_hub_ssh_key_pair() print(f"The server is available at {url_with_token}") diff --git a/cli/dstack/_internal/hub/utils/ssh.py b/cli/dstack/_internal/hub/utils/ssh.py new file mode 100644 index 000000000..f9d39b2ed --- /dev/null +++ b/cli/dstack/_internal/hub/utils/ssh.py @@ -0,0 +1,20 @@ +from pathlib import Path + +from dstack._internal.utils.crypto import generage_rsa_key_pair + +HUB_PRIVATE_KEY_PATH = Path.home() / ".dstack" / "hub" / "ssh" / "hub_ssh_key" +HUB_PUBLIC_KEY_PATH = Path.home() / ".dstack" / "hub" / "ssh" / "hub_ssh_key.pub" + + +def generate_hub_ssh_key_pair(): + if HUB_PRIVATE_KEY_PATH.exists(): + return + HUB_PRIVATE_KEY_PATH.parent.mkdir(parents=True, exist_ok=True) + generage_rsa_key_pair( + private_key_path=HUB_PRIVATE_KEY_PATH, public_key_path=HUB_PUBLIC_KEY_PATH + ) + + +def get_hub_ssh_public_key() -> str: + with open(HUB_PUBLIC_KEY_PATH) as f: + return f.read() diff --git a/cli/dstack/_internal/scripts/setup_lambda.sh b/cli/dstack/_internal/scripts/setup_lambda.sh new file mode 100644 index 
000000000..532449075 --- /dev/null +++ b/cli/dstack/_internal/scripts/setup_lambda.sh @@ -0,0 +1,21 @@ +#!/bin/bash +set -e +RUNNER_VERSION=${RUNNER_VERSION:-latest} + +function install_stgn { + sudo curl --output /usr/local/bin/dstack-runner "https://dstack-runner-downloads-stgn.s3.eu-west-1.amazonaws.com/${RUNNER_VERSION}/binaries/dstack-runner-linux-amd64" + sudo chmod +x /usr/local/bin/dstack-runner + dstack-runner --version +} + +function install_prod { + sudo curl --output /usr/local/bin/dstack-runner "https://dstack-runner-downloads.s3.eu-west-1.amazonaws.com/latest/binaries/dstack-runner-linux-amd64" + sudo chmod +x /usr/local/bin/dstack-runner + dstack-runner --version +} + +if [[ $DSTACK_STAGE == "PROD" ]]; then + install_prod +else + install_stgn +fi diff --git a/cli/dstack/_internal/utils/crypto.py b/cli/dstack/_internal/utils/crypto.py new file mode 100644 index 000000000..fce579abd --- /dev/null +++ b/cli/dstack/_internal/utils/crypto.py @@ -0,0 +1,36 @@ +import os +from pathlib import Path +from typing import Optional + +from cryptography.hazmat.backends import default_backend as crypto_default_backend +from cryptography.hazmat.primitives import serialization as crypto_serialization +from cryptography.hazmat.primitives.asymmetric import rsa + + +def generage_rsa_key_pair(private_key_path: Path, public_key_path: Optional[Path] = None): + if public_key_path is None: + public_key_path = private_key_path.with_suffix(private_key_path.suffix + ".pub") + + key = rsa.generate_private_key( + backend=crypto_default_backend(), public_exponent=65537, key_size=2048 + ) + + def key_opener(path, flags): + return os.open(path, flags, 0o600) + + with open(private_key_path, "wb", opener=key_opener) as f: + f.write( + key.private_bytes( + crypto_serialization.Encoding.PEM, + crypto_serialization.PrivateFormat.PKCS8, + crypto_serialization.NoEncryption(), + ) + ) + with open(public_key_path, "wb", opener=key_opener) as f: + f.write( + key.public_key().public_bytes( + 
crypto_serialization.Encoding.OpenSSH, + crypto_serialization.PublicFormat.OpenSSH, + ) + ) + f.write(b" dstack\n") diff --git a/runner/ami/packer/provisioners/get-dstack-runner.sh b/runner/ami/packer/provisioners/get-dstack-runner.sh index a71d53898..406096860 100755 --- a/runner/ami/packer/provisioners/get-dstack-runner.sh +++ b/runner/ami/packer/provisioners/get-dstack-runner.sh @@ -3,15 +3,16 @@ set -e RUNNER_VERSION=${RUNNER_VERSION:-latest} function install_fuse { -sudo apt install s3fs -y -if [ -e "/etc/fuse.conf" ]; then - echo "THIS /etc/fuse.conf" - sudo sed "s/# *user_allow_other/user_allow_other/" /etc/fuse.conf > t - sudo mv t /etc/fuse.conf -else - echo "user_allow_other" | tee -a /etc/fuse.conf > /dev/null -fi + sudo apt install s3fs -y + if [ -e "/etc/fuse.conf" ]; then + echo "THIS /etc/fuse.conf" + sudo sed "s/# *user_allow_other/user_allow_other/" /etc/fuse.conf > t + sudo mv t /etc/fuse.conf + else + echo "user_allow_other" | tee -a /etc/fuse.conf > /dev/null + fi } + function install_stgn { sudo curl --output /usr/local/bin/dstack-runner "https://dstack-runner-downloads-stgn.s3.eu-west-1.amazonaws.com/${RUNNER_VERSION}/binaries/dstack-runner-linux-amd64" sudo chmod +x /usr/local/bin/dstack-runner diff --git a/setup.py b/setup.py index 561007412..1b636e905 100644 --- a/setup.py +++ b/setup.py @@ -36,6 +36,7 @@ def get_long_description(): packages=find_packages("cli"), package_data={ "dstack._internal.schemas": ["*.json"], + "dstack._internal.scripts": ["*.sh"], "dstack._internal.hub": [ "statics/*", "statics/**/*", From 4ae9ec5e78910dcff040d483abcce4b82d924a1b Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 26 Jun 2023 16:16:24 +0500 Subject: [PATCH 04/16] Install docker nvidia runtime on lambda --- cli/dstack/_internal/backend/lambdalabs/compute.py | 2 +- cli/dstack/_internal/scripts/setup_lambda.sh | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/cli/dstack/_internal/backend/lambdalabs/compute.py 
b/cli/dstack/_internal/backend/lambdalabs/compute.py index 5b7831da2..3a83e849f 100644 --- a/cli/dstack/_internal/backend/lambdalabs/compute.py +++ b/cli/dstack/_internal/backend/lambdalabs/compute.py @@ -8,7 +8,7 @@ import pkg_resources import yaml -from dstack._internal.backend.base.compute import WS_PORT, choose_instance_type +from dstack._internal.backend.base.compute import WS_PORT, NoCapacityError, choose_instance_type from dstack._internal.backend.base.config import BACKEND_CONFIG_FILENAME, RUNNER_CONFIG_FILENAME from dstack._internal.backend.base.runners import serialize_runner_yaml from dstack._internal.backend.lambdalabs.api_client import LambdaAPIClient diff --git a/cli/dstack/_internal/scripts/setup_lambda.sh b/cli/dstack/_internal/scripts/setup_lambda.sh index 532449075..a49eb726d 100644 --- a/cli/dstack/_internal/scripts/setup_lambda.sh +++ b/cli/dstack/_internal/scripts/setup_lambda.sh @@ -2,6 +2,11 @@ set -e RUNNER_VERSION=${RUNNER_VERSION:-latest} +function install_nvidia_docker_runtime { + sudo apt-get update + sudo apt-get install -y --no-install-recommends nvidia-docker2 +} + function install_stgn { sudo curl --output /usr/local/bin/dstack-runner "https://dstack-runner-downloads-stgn.s3.eu-west-1.amazonaws.com/${RUNNER_VERSION}/binaries/dstack-runner-linux-amd64" sudo chmod +x /usr/local/bin/dstack-runner @@ -14,6 +19,8 @@ function install_prod { dstack-runner --version } +install_nvidia_docker_runtime + if [[ $DSTACK_STAGE == "PROD" ]]; then install_prod else From a75cb2bfdcc547c02b8575b2e95fd57b72207a76 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 26 Jun 2023 17:56:45 +0500 Subject: [PATCH 05/16] Fix lambda provisioning --- .../_internal/backend/lambdalabs/compute.py | 3 ++- .../_internal/backend/lambdalabs/config.py | 2 +- .../_internal/cli/commands/run/__init__.py | 2 +- cli/dstack/_internal/scripts/setup_lambda.sh | 1 + runner/internal/models/backend.go | 18 +++++++++++++----- 5 files changed, 18 insertions(+), 8 deletions(-) 
diff --git a/cli/dstack/_internal/backend/lambdalabs/compute.py b/cli/dstack/_internal/backend/lambdalabs/compute.py index 3a83e849f..22984c1a0 100644 --- a/cli/dstack/_internal/backend/lambdalabs/compute.py +++ b/cli/dstack/_internal/backend/lambdalabs/compute.py @@ -161,6 +161,7 @@ def _run_instance( file_system_names=[], ) instance_id = instances_ids[0] + # instance_id = api_client.list_instances()[0]["id"] instance_info = _wait_for_instance(api_client, instance_id) hostname = instance_info["ip"] _setup_instance(hostname=hostname, user_ssh_key=user_ssh_key) @@ -211,7 +212,7 @@ def _setup_instance(hostname: str, user_ssh_key: str): # so we have to update authorized_keys manually to add the user key setup_commands = ( "chmod +x .dstack/setup_lambda.sh && " - ".dstack/setup_lambda.sh && " + "RUNNER=1349 .dstack/setup_lambda.sh && " f"echo '{user_ssh_key}' >> /home/ubuntu/.ssh/authorized_keys" ) _run_ssh_command(hostname=hostname, commands=setup_commands) diff --git a/cli/dstack/_internal/backend/lambdalabs/config.py b/cli/dstack/_internal/backend/lambdalabs/config.py index 42cb82310..e47688d71 100644 --- a/cli/dstack/_internal/backend/lambdalabs/config.py +++ b/cli/dstack/_internal/backend/lambdalabs/config.py @@ -16,6 +16,6 @@ class AWSStorageConfig(BaseModel): class LambdaConfig(BaseModel): - type: Literal["lambda"] = "lambda" + backend: Literal["lambda"] = "lambda" api_key: str storage_config: AWSStorageConfig diff --git a/cli/dstack/_internal/cli/commands/run/__init__.py b/cli/dstack/_internal/cli/commands/run/__init__.py index 1f07e9acd..3e33a4989 100644 --- a/cli/dstack/_internal/cli/commands/run/__init__.py +++ b/cli/dstack/_internal/cli/commands/run/__init__.py @@ -360,7 +360,7 @@ def _attach(hub_client: HubClient, job: Job, ssh_key_path: str) -> Dict[int, int { "HostName": job.host_name, # TODO: use non-root for all backends - "User": "ubuntu" if backend_type in ("azure", "gcp") else "root", + "User": "ubuntu" if backend_type in ("azure", "gcp", "lambda") 
else "root", "IdentityFile": ssh_key_path, "StrictHostKeyChecking": "no", "UserKnownHostsFile": "/dev/null", diff --git a/cli/dstack/_internal/scripts/setup_lambda.sh b/cli/dstack/_internal/scripts/setup_lambda.sh index a49eb726d..3c3061386 100644 --- a/cli/dstack/_internal/scripts/setup_lambda.sh +++ b/cli/dstack/_internal/scripts/setup_lambda.sh @@ -5,6 +5,7 @@ RUNNER_VERSION=${RUNNER_VERSION:-latest} function install_nvidia_docker_runtime { sudo apt-get update sudo apt-get install -y --no-install-recommends nvidia-docker2 + sudo pkill -SIGHUP dockerd } function install_stgn { diff --git a/runner/internal/models/backend.go b/runner/internal/models/backend.go index fdb03ce19..4cebfcb21 100644 --- a/runner/internal/models/backend.go +++ b/runner/internal/models/backend.go @@ -163,15 +163,15 @@ func (j *Job) JobHeadFilepathPrefix() string { func (j *Job) JobHeadFilepath() string { appsSlice := make([]string, len(j.Apps)) - for _, app := range j.Apps { - appsSlice = append(appsSlice, app.Name) + for i, app := range j.Apps { + appsSlice[i] = app.Name } artifactSlice := make([]string, len(j.Artifacts)) - for _, art := range j.Artifacts { - artifactSlice = append(artifactSlice, EscapeHead(art.Path)) + for i, art := range j.Artifacts { + artifactSlice[i] = EscapeHead(art.Path) } return fmt.Sprintf( - "jobs/%s/l;%s;%s;%s;%d;%s;%s;%s;%s;%s;%s", + "jobs/%s/l;%s;%s;%s;%d;%s;%s;%s;%s;%s;%s;%s", j.RepoId, j.JobID, j.ProviderName, @@ -183,9 +183,17 @@ func (j *Job) JobHeadFilepath() string { j.TagName, j.InstanceType, EscapeHead(j.ConfigurationPath), + j.GetInstanceType(), ) } +func (j *Job) GetInstanceType() string { + if j.Requirements.Spot { + return "spot" + } + return "on-demand" +} + func (j *Job) SecretsPrefix() string { return fmt.Sprintf("secrets/%s/l;", j.RepoId) } From c97908ae96d272accd91ed83c1cd70b47f3fac97 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 27 Jun 2023 10:32:15 +0500 Subject: [PATCH 06/16] Terminate lambda instance from runner --- 
runner/internal/backend/aws/backend.go | 68 ++++++++++---------- runner/internal/backend/lambda/api_client.go | 8 ++- runner/internal/backend/lambda/backend.go | 2 +- 3 files changed, 41 insertions(+), 37 deletions(-) diff --git a/runner/internal/backend/aws/backend.go b/runner/internal/backend/aws/backend.go index 259aac733..ac08cb05c 100644 --- a/runner/internal/backend/aws/backend.go +++ b/runner/internal/backend/aws/backend.go @@ -29,10 +29,10 @@ import ( ) type AWSBackend struct { + State *models.State region string bucket string runnerID string - state *models.State artifacts []artifacts.Artifacter cliS3 *ClientS3 cliEC2 *ClientEC2 @@ -89,11 +89,11 @@ func (s *AWSBackend) Init(ctx context.Context, ID string) error { return gerrors.Wrap(err) } log.Trace(ctx, "Unmarshal state") - err = yaml.Unmarshal(theFile, &s.state) + err = yaml.Unmarshal(theFile, &s.State) if err != nil { return gerrors.Wrap(err) } - if s.state == nil { + if s.State == nil { return gerrors.New("State is empty. Data not loading") } return nil @@ -104,25 +104,25 @@ func (s *AWSBackend) Job(ctx context.Context) *models.Job { if s == nil { return new(models.Job) } - if s.state == nil { + if s.State == nil { log.Trace(ctx, "State not exist") return new(models.Job) } - log.Trace(ctx, "Get job", "job ID", s.state.Job.JobID) - return s.state.Job + log.Trace(ctx, "Get job", "job ID", s.State.Job.JobID) + return s.State.Job } func (s *AWSBackend) RefetchJob(ctx context.Context) (*models.Job, error) { - log.Trace(ctx, "Refetching job from state", "ID", s.state.Job.JobID) - contents, err := s.cliS3.GetFile(ctx, s.bucket, s.state.Job.JobFilepath()) + log.Trace(ctx, "Refetching job from state", "ID", s.State.Job.JobID) + contents, err := s.cliS3.GetFile(ctx, s.bucket, s.State.Job.JobFilepath()) if err != nil { return nil, gerrors.Wrap(err) } - err = yaml.Unmarshal(contents, &s.state.Job) + err = yaml.Unmarshal(contents, &s.State.Job) if err != nil { return nil, gerrors.Wrap(err) } - return s.state.Job, 
nil + return s.State.Job, nil } func (s *AWSBackend) UpdateState(ctx context.Context) error { @@ -130,27 +130,27 @@ func (s *AWSBackend) UpdateState(ctx context.Context) error { if s == nil { return gerrors.New("Backend is nil") } - if s.state == nil { + if s.State == nil { log.Trace(ctx, "State not exist") return gerrors.Wrap(backend.ErrLoadStateFile) } log.Trace(ctx, "Marshaling job") - theFile, err := yaml.Marshal(&s.state.Job) + theFile, err := yaml.Marshal(&s.State.Job) if err != nil { return gerrors.Wrap(err) } - jobFilepath := s.state.Job.JobFilepath() + jobFilepath := s.State.Job.JobFilepath() log.Trace(ctx, "Write to file job", "Path", jobFilepath) err = s.cliS3.PutFile(ctx, s.bucket, jobFilepath, theFile) if err != nil { return gerrors.Wrap(err) } - log.Trace(ctx, "Fetching list jobs", "Repo username", s.state.Job.RepoUserName, "Repo name", s.state.Job.RepoName, "Job ID", s.state.Job.JobID) - files, err := s.cliS3.ListFile(ctx, s.bucket, s.state.Job.JobHeadFilepathPrefix()) + log.Trace(ctx, "Fetching list jobs", "Repo username", s.State.Job.RepoUserName, "Repo name", s.State.Job.RepoName, "Job ID", s.State.Job.JobID) + files, err := s.cliS3.ListFile(ctx, s.bucket, s.State.Job.JobHeadFilepathPrefix()) if err != nil { return gerrors.Wrap(err) } - jobHeadFilepath := s.state.Job.JobHeadFilepath() + jobHeadFilepath := s.State.Job.JobHeadFilepath() for _, file := range files[:1] { log.Trace(ctx, "Renaming file job", "From", file, "To", jobHeadFilepath) err = s.cliS3.RenameFile(ctx, s.bucket, file, jobHeadFilepath) @@ -165,7 +165,7 @@ func (s *AWSBackend) CheckStop(ctx context.Context) (bool, error) { if s == nil { return false, gerrors.New("Backend is nil") } - if s.state == nil { + if s.State == nil { log.Trace(ctx, "State not exist") return false, gerrors.Wrap(backend.ErrLoadStateFile) } @@ -183,10 +183,10 @@ func (s *AWSBackend) CheckStop(ctx context.Context) (bool, error) { } func (s *AWSBackend) IsInterrupted(ctx context.Context) (bool, error) { - if 
!s.state.Resources.Spot { + if !s.State.Resources.Spot { return false, nil } - return s.cliEC2.IsInterruptedSpot(ctx, s.state.RequestID) + return s.cliEC2.IsInterruptedSpot(ctx, s.State.RequestID) } func (s *AWSBackend) Shutdown(ctx context.Context) error { @@ -194,15 +194,15 @@ func (s *AWSBackend) Shutdown(ctx context.Context) error { if s == nil { return gerrors.New("Backend is nil") } - if s.state.Resources.Spot { + if s.State.Resources.Spot { log.Trace(ctx, "Instance interruptible") - if err := s.cliEC2.CancelSpot(ctx, s.state.RequestID); err != nil { + if err := s.cliEC2.CancelSpot(ctx, s.State.RequestID); err != nil { return gerrors.Wrap(err) } return nil } log.Trace(ctx, "Instance not interruptible") - return s.cliEC2.TerminateInstance(ctx, s.state.RequestID) + return s.cliEC2.TerminateInstance(ctx, s.State.RequestID) } @@ -245,23 +245,23 @@ func (s *AWSBackend) Requirements(ctx context.Context) models.Requirements { if s == nil { return models.Requirements{} } - if s.state == nil { + if s.State == nil { log.Trace(ctx, "State not exist") return models.Requirements{} } log.Trace(ctx, "Return model resource") - return s.state.Job.Requirements + return s.State.Job.Requirements } func (s *AWSBackend) MasterJob(ctx context.Context) *models.Job { if s == nil { return new(models.Job) } - if s.state == nil { + if s.State == nil { log.Trace(ctx, "State not exist") return nil } - theFile, err := s.cliS3.GetFile(ctx, s.bucket, fmt.Sprintf("jobs/%s/%s.yaml", s.state.Job.RepoId, s.state.Job.MasterJobID)) + theFile, err := s.cliS3.GetFile(ctx, s.bucket, fmt.Sprintf("jobs/%s/%s.yaml", s.State.Job.RepoId, s.State.Job.MasterJobID)) if err != nil { return nil } @@ -277,7 +277,7 @@ func (s *AWSBackend) CreateLogger(ctx context.Context, logGroup, logName string) if s == nil { return nil } - if s.state == nil { + if s.State == nil { log.Trace(ctx, "State not exist") return nil } @@ -285,7 +285,7 @@ func (s *AWSBackend) CreateLogger(ctx context.Context, logGroup, logName string) 
if s.logger == nil { log.Trace(ctx, "Create Cloudwatch") s.logger, err = NewCloudwatch(&Config{ - JobID: s.state.Job.JobID, + JobID: s.State.Job.JobID, Region: s.region, FlushInterval: 200 * time.Millisecond, }) @@ -340,10 +340,10 @@ func (s *AWSBackend) Secrets(ctx context.Context) (map[string]string, error) { if s == nil { return nil, gerrors.New("Backend is nil") } - if s.state == nil { + if s.State == nil { return nil, gerrors.New("State is empty") } - templatePath := fmt.Sprintf("secrets/%s/l;", s.state.Job.RepoId) + templatePath := fmt.Sprintf("secrets/%s/l;", s.State.Job.RepoId) listSecrets, err := s.cliS3.ListFile(ctx, s.bucket, templatePath) if err != nil { return nil, gerrors.Wrap(err) @@ -352,7 +352,7 @@ func (s *AWSBackend) Secrets(ctx context.Context) (map[string]string, error) { for _, secretPath := range listSecrets { clearName := strings.ReplaceAll(secretPath, templatePath, "") secrets[clearName] = fmt.Sprintf("%s/%s", - s.state.Job.RepoId, + s.State.Job.RepoId, clearName) } return s.cliSecret.fetchSecret(ctx, s.bucket, secrets) @@ -364,15 +364,15 @@ func (s *AWSBackend) GitCredentials(ctx context.Context) *models.GitCredentials log.Error(ctx, "Backend is empty") return nil } - if s.state == nil { + if s.State == nil { log.Error(ctx, "State is empty") return nil } - if s.state.Job == nil { + if s.State.Job == nil { log.Error(ctx, "Job is empty") return nil } - return s.cliSecret.fetchCredentials(ctx, s.bucket, s.state.Job.RepoId) + return s.cliSecret.fetchCredentials(ctx, s.bucket, s.State.Job.RepoId) } func (s *AWSBackend) GetRepoDiff(ctx context.Context, path string) (string, error) { diff --git a/runner/internal/backend/lambda/api_client.go b/runner/internal/backend/lambda/api_client.go index d7ac6ec7e..5dfc10970 100644 --- a/runner/internal/backend/lambda/api_client.go +++ b/runner/internal/backend/lambda/api_client.go @@ -15,12 +15,16 @@ type LambdaAPIClient struct { apiKey string } +type TerminateInstanceRequest struct { + InstanceIDs []string 
`json:"instance_ids"` +} + func NewLambdaAPIClient(apiKey string) *LambdaAPIClient { return &LambdaAPIClient{apiKey: apiKey} } -func (client *LambdaAPIClient) TerminateInstance(ctx context.Context, instancesIDs []string) error { - body, err := json.Marshal(instancesIDs) +func (client *LambdaAPIClient) TerminateInstance(ctx context.Context, instanceIDs []string) error { + body, err := json.Marshal(TerminateInstanceRequest{InstanceIDs: instanceIDs}) if err != nil { return gerrors.Wrap(err) } diff --git a/runner/internal/backend/lambda/backend.go b/runner/internal/backend/lambda/backend.go index 08b585688..2f06e153d 100644 --- a/runner/internal/backend/lambda/backend.go +++ b/runner/internal/backend/lambda/backend.go @@ -89,7 +89,7 @@ func (l *LambdaBackend) IsInterrupted(ctx context.Context) (bool, error) { } func (l *LambdaBackend) Shutdown(ctx context.Context) error { - return nil + return l.apiClient.TerminateInstance(ctx, []string{l.storageBackend.State.RequestID}) } func (l *LambdaBackend) GetArtifact(ctx context.Context, runName, localPath, remotePath string, mount bool) artifacts.Artifacter { From 9af46c0b106514ddec264ad73f8e4e86866f83f0 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 27 Jun 2023 11:41:32 +0500 Subject: [PATCH 07/16] Fix done job being marked as failed after termination --- cli/dstack/_internal/backend/base/runs.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/cli/dstack/_internal/backend/base/runs.py b/cli/dstack/_internal/backend/base/runs.py index 3e07f853e..68214524e 100644 --- a/cli/dstack/_internal/backend/base/runs.py +++ b/cli/dstack/_internal/backend/base/runs.py @@ -102,9 +102,13 @@ def _create_run( job.error_code = JobErrorCode.INTERRUPTED_BY_NO_CAPACITY jobs.update_job(storage, job) elif request_head.status == RequestStatus.TERMINATED: - job.status = job_head.status = JobStatus.FAILED - job.error_code = JobErrorCode.INSTANCE_TERMINATED - jobs.update_job(storage, job) + # We should 
check the job status again, to ensure it hasn't been updated just + # before the instance was terminated + job = jobs.get_job(storage, job_head.repo_ref.repo_id, job_head.job_id) + if job.status.is_unfinished(): + job.status = job_head.status = JobStatus.FAILED + job.error_code = JobErrorCode.INSTANCE_TERMINATED + jobs.update_job(storage, job) run_head = RunHead( run_name=job_head.run_name, workflow_name=job_head.workflow_name, @@ -172,8 +176,10 @@ def _update_run( job.error_code = JobErrorCode.INTERRUPTED_BY_NO_CAPACITY jobs.update_job(storage, job) elif request_head.status == RequestStatus.TERMINATED: - job.status = job_head.status = JobStatus.FAILED - job.error_code = JobErrorCode.INSTANCE_TERMINATED - jobs.update_job(storage, job) + job = jobs.get_job(storage, job_head.repo_ref.repo_id, job_head.job_id) + if job.status.is_unfinished(): + job.status = job_head.status = JobStatus.FAILED + job.error_code = JobErrorCode.INSTANCE_TERMINATED + jobs.update_job(storage, job) run.status = job_head.status run.job_heads.append(job_head) From e2772ffa5b398a604d306bbeaad5bb8abfa3338a Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 27 Jun 2023 11:41:59 +0500 Subject: [PATCH 08/16] Set up lambda for prod and stage --- cli/dstack/_internal/scripts/setup_lambda.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cli/dstack/_internal/scripts/setup_lambda.sh b/cli/dstack/_internal/scripts/setup_lambda.sh index 3c3061386..5c2349794 100644 --- a/cli/dstack/_internal/scripts/setup_lambda.sh +++ b/cli/dstack/_internal/scripts/setup_lambda.sh @@ -15,14 +15,14 @@ function install_stgn { } function install_prod { - sudo curl --output /usr/local/bin/dstack-runner "https://dstack-runner-downloads.s3.eu-west-1.amazonaws.com/latest/binaries/dstack-runner-linux-amd64" + sudo curl --output /usr/local/bin/dstack-runner "https://dstack-runner-downloads.s3.eu-west-1.amazonaws.com/${RUNNER_VERSION}/binaries/dstack-runner-linux-amd64" sudo chmod +x 
/usr/local/bin/dstack-runner dstack-runner --version } install_nvidia_docker_runtime -if [[ $DSTACK_STAGE == "PROD" ]]; then +if [[ $ENVIRONMENT == "PROD" ]]; then install_prod else install_stgn From bb041db702640df3010c393d979e0e7159db16b1 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 27 Jun 2023 11:42:39 +0500 Subject: [PATCH 09/16] Check available regions on lambda --- cli/dstack/_internal/backend/base/compute.py | 1 + .../_internal/backend/lambdalabs/compute.py | 34 +++++++++++++++---- cli/dstack/_internal/core/instance.py | 3 ++ 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/cli/dstack/_internal/backend/base/compute.py b/cli/dstack/_internal/backend/base/compute.py index 171954edc..cd08fa73d 100644 --- a/cli/dstack/_internal/backend/base/compute.py +++ b/cli/dstack/_internal/backend/base/compute.py @@ -66,6 +66,7 @@ def choose_instance_type( spot=spot, local=False, ), + available_regions=instance_type.available_regions, ) diff --git a/cli/dstack/_internal/backend/lambdalabs/compute.py b/cli/dstack/_internal/backend/lambdalabs/compute.py index 22984c1a0..25d440d9f 100644 --- a/cli/dstack/_internal/backend/lambdalabs/compute.py +++ b/cli/dstack/_internal/backend/lambdalabs/compute.py @@ -3,11 +3,13 @@ import subprocess import tempfile import time +from threading import Thread from typing import Dict, List, Optional, Tuple import pkg_resources import yaml +from dstack import version from dstack._internal.backend.base.compute import WS_PORT, NoCapacityError, choose_instance_type from dstack._internal.backend.base.config import BACKEND_CONFIG_FILENAME, RUNNER_CONFIG_FILENAME from dstack._internal.backend.base.runners import serialize_runner_yaml @@ -47,7 +49,7 @@ def get_instance_type(self, job: Job) -> Optional[InstanceType]: def run_instance(self, job: Job, instance_type: InstanceType) -> str: return _run_instance( api_client=self.api_client, - region_name="us-west-1", + region_name=instance_type.available_regions[0], 
instance_type_name=instance_type.instance_name, user_ssh_key=job.ssh_key_pub, hub_ssh_key=get_hub_ssh_public_key(), @@ -83,7 +85,8 @@ def _get_instance_info(api_client: LambdaAPIClient, instance_id: str) -> Optiona def _instance_type_data_to_instance_type(instance_type_data: Dict) -> Optional[InstanceType]: instance_type = instance_type_data["instance_type"] - regions = instance_type_data["regions_with_capacity_available"] + regions_data = instance_type_data["regions_with_capacity_available"] + regions = [r["name"] for r in regions_data] if len(regions) == 0: return None instance_type_specs = instance_type["specs"] @@ -99,6 +102,7 @@ def _instance_type_data_to_instance_type(instance_type_data: Dict) -> Optional[I spot=False, local=False, ), + available_regions=regions, ) @@ -163,9 +167,15 @@ def _run_instance( instance_id = instances_ids[0] # instance_id = api_client.list_instances()[0]["id"] instance_info = _wait_for_instance(api_client, instance_id) - hostname = instance_info["ip"] - _setup_instance(hostname=hostname, user_ssh_key=user_ssh_key) - _launch_runner(hostname=hostname, launch_script=launch_script) + thread = Thread( + target=_start_runner, + kwargs={ + "hostname": instance_info["ip"], + "user_ssh_key": user_ssh_key, + "launch_script": launch_script, + }, + ) + thread.start() return instance_id @@ -202,6 +212,11 @@ def _wait_for_instance( time.sleep(_WAIT_FOR_INSTANCE_INTERVAL) +def _start_runner(hostname: str, user_ssh_key: str, launch_script: str): + _setup_instance(hostname, user_ssh_key) + _launch_runner(hostname, launch_script) + + def _setup_instance(hostname: str, user_ssh_key: str): setup_script_path = pkg_resources.resource_filename( "dstack._internal", "scripts/setup_lambda.sh" @@ -212,12 +227,19 @@ def _setup_instance(hostname: str, user_ssh_key: str): # so we have to update authorized_keys manually to add the user key setup_commands = ( "chmod +x .dstack/setup_lambda.sh && " - "RUNNER=1349 .dstack/setup_lambda.sh && " + 
f"{_get_setup_command()} && " f"echo '{user_ssh_key}' >> /home/ubuntu/.ssh/authorized_keys" ) _run_ssh_command(hostname=hostname, commands=setup_commands) +def _get_setup_command() -> str: + if not version.__is_release__: + # TODO: replace with latest after merge + return "ENVIRONMENT=stage RUNNER=1351 .dstack/setup_lambda.sh" + return f"ENVIRONMENT=prod RUNNER={version.__version__} .dstack/setup_lambda.sh" + + def _launch_runner(hostname: str, launch_script: str): launch_script = launch_script.replace("HOSTNAME_PLACEHOLDER", hostname) with tempfile.NamedTemporaryFile("w+") as f: diff --git a/cli/dstack/_internal/core/instance.py b/cli/dstack/_internal/core/instance.py index 88c449581..0b297f0bd 100644 --- a/cli/dstack/_internal/core/instance.py +++ b/cli/dstack/_internal/core/instance.py @@ -1,3 +1,5 @@ +from typing import List, Optional + from pydantic import BaseModel from dstack._internal.core.runners import Resources @@ -6,3 +8,4 @@ class InstanceType(BaseModel): instance_name: str resources: Resources + available_regions: Optional[List[str]] = None From 40ce212f322eb3b8f980f9efa01e120f3a763abd Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 28 Jun 2023 12:13:12 +0500 Subject: [PATCH 10/16] Implement lambda configurator --- .../_internal/backend/lambdalabs/config.py | 1 + cli/dstack/_internal/hub/models/__init__.py | 32 +++- cli/dstack/_internal/hub/routers/projects.py | 6 +- .../backends/lambdalabs/configurator.py | 159 ++++++++++++++++-- 4 files changed, 182 insertions(+), 16 deletions(-) diff --git a/cli/dstack/_internal/backend/lambdalabs/config.py b/cli/dstack/_internal/backend/lambdalabs/config.py index e47688d71..fa6bf4ef0 100644 --- a/cli/dstack/_internal/backend/lambdalabs/config.py +++ b/cli/dstack/_internal/backend/lambdalabs/config.py @@ -10,6 +10,7 @@ class AWSStorageConfigCredentials(BaseModel): class AWSStorageConfig(BaseModel): + backend: Literal["aws"] = "aws" bucket: str region: str credentials: AWSStorageConfigCredentials diff 
--git a/cli/dstack/_internal/hub/models/__init__.py b/cli/dstack/_internal/hub/models/__init__.py index d2b5ddc3e..5e4a402e7 100644 --- a/cli/dstack/_internal/hub/models/__init__.py +++ b/cli/dstack/_internal/hub/models/__init__.py @@ -164,12 +164,35 @@ class AzureProjectConfigWithCreds(AzureProjectConfig): credentials: AzureProjectCreds +class AWSStorageProjectConfigWithCredsPartial(BaseModel): + type: Literal["aws"] = "aws" + bucket_name: Optional[str] + credentials: Optional[AWSProjectAccessKeyCreds] + + +class AWSStorageProjectConfig(BaseModel): + type: Literal["aws"] = "aws" + bucket_name: str + + +class AWSStorageProjectConfigWithCreds(AWSStorageProjectConfig): + credentials: AWSProjectAccessKeyCreds + + +class LambdaProjectConfigWithCredsPartial(BaseModel): + type: Literal["lambda"] = "lambda" + api_key: Optional[str] + storage_backend: Optional[AWSStorageProjectConfigWithCredsPartial] + + class LambdaProjectConfig(BaseModel): type: Literal["lambda"] = "lambda" + storage_backend: AWSStorageProjectConfig class LambdaProjectConfigWithCreds(LambdaProjectConfig): api_key: str + storage_backend: AWSStorageProjectConfigWithCreds AnyProjectConfig = Union[ @@ -180,7 +203,7 @@ class LambdaProjectConfigWithCreds(LambdaProjectConfig): AWSProjectConfigWithCredsPartial, GCPProjectConfigWithCredsPartial, AzureProjectConfigWithCredsPartial, - LambdaProjectConfig, + LambdaProjectConfigWithCredsPartial, ] AnyProjectConfigWithCreds = Union[ LocalProjectConfig, @@ -378,8 +401,15 @@ class AzureProjectValues(BaseModel): storage_account: Optional[ProjectElement] +class AWSStorageBackendValues(BaseModel): + type: Literal["aws"] = "aws" + bucket_name: Optional[ProjectElement] + + class LambdaProjectValues(BaseModel): type: Literal["lambda"] = "lambda" + storage_backend_type: ProjectElement + storage_backend_values: Optional[AWSStorageBackendValues] class ProjectValues(BaseModel): diff --git a/cli/dstack/_internal/hub/routers/projects.py 
b/cli/dstack/_internal/hub/routers/projects.py index f43f47297..55c22a659 100644 --- a/cli/dstack/_internal/hub/routers/projects.py +++ b/cli/dstack/_internal/hub/routers/projects.py @@ -33,7 +33,7 @@ async def get_backend_config_values( ) -> ProjectValues: configurator = get_backend_configurator(config.__root__.type) try: - result = await run_async(configurator.configure_project, config.__root__.dict()) + result = await run_async(configurator.configure_project, config.__root__) except BackendConfigError as e: _error_response_on_config_error(e, path_to_config=[]) return result @@ -62,7 +62,7 @@ async def create_project( ) configurator = get_backend_configurator(project_info.backend.__root__.type) try: - await run_async(configurator.configure_project, project_info.backend.__root__.dict()) + await run_async(configurator.configure_project, project_info.backend.__root__) await ProjectManager.create_project_from_info(user=user, project_info=project_info) except BackendConfigError as e: _error_response_on_config_error(e, path_to_config=["backend"]) @@ -113,7 +113,7 @@ async def update_project( ) -> ProjectInfoWithCreds: configurator = get_backend_configurator(project_info.backend.__root__.type) try: - await run_async(configurator.configure_project, project_info.backend.__root__.dict()) + await run_async(configurator.configure_project, project_info.backend.__root__) except BackendConfigError as e: _error_response_on_config_error(e, path_to_config=["backend"]) await ProjectManager.update_project_from_info(project_info) diff --git a/cli/dstack/_internal/hub/services/backends/lambdalabs/configurator.py b/cli/dstack/_internal/hub/services/backends/lambdalabs/configurator.py index 9bf7087a7..5648fed30 100644 --- a/cli/dstack/_internal/hub/services/backends/lambdalabs/configurator.py +++ b/cli/dstack/_internal/hub/services/backends/lambdalabs/configurator.py @@ -1,7 +1,12 @@ -import os -from typing import Dict, Tuple, Union +import json +from typing import Dict, Optional, 
Tuple, Union + +import botocore +from boto3.session import Session +from requests import HTTPError from dstack._internal.backend.lambdalabs import LambdaBackend +from dstack._internal.backend.lambdalabs.api_client import LambdaAPIClient from dstack._internal.backend.lambdalabs.config import ( AWSStorageConfig, AWSStorageConfigCredentials, @@ -9,10 +14,20 @@ ) from dstack._internal.hub.db.models import Project from dstack._internal.hub.models import ( + AWSProjectAccessKeyCreds, + AWSStorageBackendValues, + AWSStorageProjectConfig, + AWSStorageProjectConfigWithCreds, + AWSStorageProjectConfigWithCredsPartial, LambdaProjectConfig, LambdaProjectConfigWithCreds, + LambdaProjectConfigWithCredsPartial, + LambdaProjectValues, + ProjectElement, + ProjectElementValue, ProjectValues, ) +from dstack._internal.hub.services.backends.base import BackendConfigError class LambdaConfigurator: @@ -21,32 +36,152 @@ class LambdaConfigurator: def get_backend_class(self) -> type: return LambdaBackend - def configure_project(self, config_data: Dict) -> ProjectValues: - return None + def configure_project(self, config: LambdaProjectConfigWithCredsPartial) -> ProjectValues: + selected_storage_backend = None + if config.storage_backend is not None: + selected_storage_backend = config.storage_backend.type + storage_backend_type = self._get_storage_backend_type_element( + selected=selected_storage_backend + ) + project_values = LambdaProjectValues(storage_backend_type=storage_backend_type) + if config.api_key is None: + return project_values + self._validate_lambda_api_key(api_key=config.api_key) + if config.storage_backend is None or config.storage_backend.credentials is None: + return project_values + project_values.storage_backend_values = self._get_aws_storage_backend_values( + config.storage_backend + ) + return project_values def create_config_auth_data_from_project_config( self, project_config: LambdaProjectConfigWithCreds ) -> Tuple[Dict, Dict]: - config_data = 
LambdaProjectConfig.parse_obj(project_config).dict() - auth_data = {"api_key": project_config.api_key} + config_data = { + "backend": "lambda", + "storage_backend": self._get_aws_storage_backend_config_data( + project_config.storage_backend + ), + } + auth_data = { + "api_key": project_config.api_key, + "storage_backend": {"credentials": project_config.storage_backend.credentials.dict()}, + } return config_data, auth_data def get_project_config_from_project( self, project: Project, include_creds: bool ) -> Union[LambdaProjectConfig, LambdaProjectConfigWithCreds]: - return LambdaProjectConfig() + config_data = json.loads(project.config) + if include_creds: + auth_data = json.loads(project.auth) + return LambdaProjectConfigWithCreds( + api_key=auth_data["api_key"], + storage_backend=AWSStorageProjectConfigWithCreds( + bucket_name=config_data["storage_backend"]["bucket"], + credentials=AWSProjectAccessKeyCreds.parse_obj( + auth_data["storage_backend"]["credentials"] + ), + ), + ) + return LambdaProjectConfig( + storage_backend=AWSStorageProjectConfig( + bucket_name=config_data["storage_backend"]["bucket"] + ) + ) def get_backend_config_from_hub_config_data( self, project_name: str, config_data: Dict, auth_data: Dict ) -> LambdaConfig: return LambdaConfig( - api_key=os.environ["LAMBDA_API_KEY"], + api_key=auth_data["api_key"], storage_config=AWSStorageConfig( - region="eu-west-1", - bucket="dstack-lambda-eu-west-1", + bucket=config_data["storage_backend"]["bucket"], + region=config_data["storage_backend"]["region"], credentials=AWSStorageConfigCredentials( - access_key=os.environ["AWS_ACCESS_KEY_ID"], - secret_key=os.environ["AWS_SECRET_ACCESS_KEY"], + access_key=auth_data["storage_backend"]["credentials"]["access_key"], + secret_key=auth_data["storage_backend"]["credentials"]["secret_key"], ), ), ) + + def _get_storage_backend_type_element(self, selected: Optional[str]) -> ProjectElement: + element = ProjectElement( + values=[ProjectElementValue(value="aws", 
label="AWS")], selected="aws" + ) + return element + + def _validate_lambda_api_key(self, api_key: str): + client = LambdaAPIClient(api_key=api_key) + try: + client.list_instance_types() + except HTTPError as e: + if e.response.status_code in [401, 403]: + raise BackendConfigError( + "Invalid credentials", + code="invalid_credentials", + fields=[["api_key"]], + ) + raise e + + def _get_aws_storage_backend_values( + self, config: AWSStorageProjectConfigWithCredsPartial + ) -> AWSStorageBackendValues: + session = Session( + aws_access_key_id=config.credentials.access_key, + aws_secret_access_key=config.credentials.secret_key, + ) + self._validate_aws_credentials(session=session) + storage_backend_values = AWSStorageBackendValues() + storage_backend_values.bucket_name = self._get_aws_bucket_element( + session=session, selected=config.bucket_name + ) + return storage_backend_values + + def _validate_aws_credentials(self, session: Session): + sts = session.client("sts") + try: + sts.get_caller_identity() + except botocore.exceptions.ClientError: + raise BackendConfigError( + "Invalid credentials", + code="invalid_credentials", + fields=[ + ["storage_backend", "credentials", "access_key"], + ["storage_backend", "credentials", "secret_key"], + ], + ) + + def _get_aws_bucket_element( + self, session: Session, selected: Optional[str] = None + ) -> ProjectElement: + element = ProjectElement(selected=selected) + s3_client = session.client("s3") + response = s3_client.list_buckets() + for bucket in response["Buckets"]: + element.values.append( + ProjectElementValue( + value=bucket["Name"], + label=bucket["Name"], + ) + ) + return element + + def _get_aws_storage_backend_config_data( + self, config: AWSStorageProjectConfigWithCreds + ) -> Dict: + session = Session( + aws_access_key_id=config.credentials.access_key, + aws_secret_access_key=config.credentials.secret_key, + ) + return { + "backend": "aws", + "bucket": config.bucket_name, + "region": 
self._get_aws_bucket_region(session, config.bucket_name), + } + + def _get_aws_bucket_region(self, session: Session, bucket: str) -> str: + s3_client = session.client("s3") + response = s3_client.head_bucket(Bucket=bucket) + region = response["ResponseMetadata"]["HTTPHeaders"]["x-amz-bucket-region"] + return region From fb63a5f315c59b307a1c7d86ba3ea3504bdaeba6 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 28 Jun 2023 14:59:54 +0500 Subject: [PATCH 11/16] Refactor backend configurators --- .../_internal/backend/lambdalabs/config.py | 2 - .../_internal/hub/repository/projects.py | 7 +- cli/dstack/_internal/hub/routers/cache.py | 8 +- .../hub/services/backends/aws/configurator.py | 79 ++++---- .../services/backends/azure/configurator.py | 87 +++++---- .../_internal/hub/services/backends/base.py | 25 ++- .../hub/services/backends/gcp/configurator.py | 176 +++++++++--------- .../backends/lambdalabs/configurator.py | 44 ++--- .../services/backends/local/configurator.py | 26 +-- 9 files changed, 232 insertions(+), 222 deletions(-) diff --git a/cli/dstack/_internal/backend/lambdalabs/config.py b/cli/dstack/_internal/backend/lambdalabs/config.py index fa6bf4ef0..55770aae0 100644 --- a/cli/dstack/_internal/backend/lambdalabs/config.py +++ b/cli/dstack/_internal/backend/lambdalabs/config.py @@ -1,5 +1,3 @@ -from typing import Dict, Optional - from pydantic import BaseModel from typing_extensions import Literal diff --git a/cli/dstack/_internal/hub/repository/projects.py b/cli/dstack/_internal/hub/repository/projects.py index 9d689bcf3..970fa4992 100644 --- a/cli/dstack/_internal/hub/repository/projects.py +++ b/cli/dstack/_internal/hub/repository/projects.py @@ -1,4 +1,3 @@ -import asyncio import json from typing import List, Optional, Union @@ -162,9 +161,7 @@ async def _info2project(project_info: ProjectInfoWithCreds) -> Project: backend=project_info.backend.type, ) configurator = get_configurator(project.backend) - config, auth = await run_async( - 
configurator.create_config_auth_data_from_project_config, project_info.backend - ) + config, auth = await run_async(configurator.create_project, project_info.backend) project.config = json.dumps(config) project.auth = json.dumps(auth) return project @@ -184,7 +181,7 @@ def _project2info( configurator = get_configurator(project.backend) if configurator is None: return None - backend = configurator.get_project_config_from_project(project, include_creds=include_creds) + backend = configurator.get_project_config(project, include_creds=include_creds) if include_creds: return ProjectInfoWithCreds(project_name=project.name, backend=backend, members=members) return ProjectInfo(project_name=project.name, backend=backend, members=members) diff --git a/cli/dstack/_internal/hub/routers/cache.py b/cli/dstack/_internal/hub/routers/cache.py index ad2171351..02d21f500 100644 --- a/cli/dstack/_internal/hub/routers/cache.py +++ b/cli/dstack/_internal/hub/routers/cache.py @@ -16,13 +16,7 @@ async def get_backend(project: Project) -> Optional[Backend]: key = project.name if cache.get(key) is not None: return cache[key] - json_data = json.loads(str(project.config)) - auth_data = json.loads(str(project.auth)) - config = configurator.get_backend_config_from_hub_config_data( - project.name, json_data, auth_data - ) - backend_cls = configurator.get_backend_class() - backend = await run_async(backend_cls, config) + backend = await run_async(configurator.get_backend, project) cache[key] = backend return cache[key] diff --git a/cli/dstack/_internal/hub/services/backends/aws/configurator.py b/cli/dstack/_internal/hub/services/backends/aws/configurator.py index 282ccc3a5..bd0a89161 100644 --- a/cli/dstack/_internal/hub/services/backends/aws/configurator.py +++ b/cli/dstack/_internal/hub/services/backends/aws/configurator.py @@ -6,13 +6,13 @@ from dstack._internal.backend.aws import AwsBackend from dstack._internal.backend.aws.config import AWSConfig -from dstack._internal.backend.base.config 
import BackendConfig from dstack._internal.hub.db.models import Project from dstack._internal.hub.models import ( AWSBucketProjectElement, AWSBucketProjectElementValue, AWSProjectConfig, AWSProjectConfigWithCreds, + AWSProjectConfigWithCredsPartial, AWSProjectCreds, AWSProjectValues, ProjectElement, @@ -20,7 +20,7 @@ ) from dstack._internal.hub.services.backends.base import BackendConfigError, Configurator -regions = [ +REGIONS = [ ("US East, N. Virginia", "us-east-1"), ("US East, Ohio", "us-east-2"), ("US West, N. California", "us-west-1"), @@ -38,31 +38,30 @@ class AWSConfigurator(Configurator): NAME = "aws" - def get_backend_class(self) -> type: - return AwsBackend - - def configure_project(self, config_data: Dict) -> AWSProjectValues: - config = AWSConfig.deserialize(config_data) - - if config.region_name is not None and config.region_name not in {r[1] for r in regions}: - raise BackendConfigError(f"Invalid AWS region {config.region_name}") + def configure_project( + self, project_config: AWSProjectConfigWithCredsPartial + ) -> AWSProjectValues: + if project_config.region_name is not None and project_config.region_name not in { + r[1] for r in REGIONS + }: + raise BackendConfigError(f"Invalid AWS region {project_config.region_name}") project_values = AWSProjectValues() session = Session() if session.region_name is None: - session = Session(region_name=config.region_name) + session = Session(region_name=project_config.region_name) project_values.default_credentials = self._valid_credentials(session=session) - credentials_data = config_data.get("credentials") - if credentials_data is None: + if project_config.credentials is None: return project_values - if credentials_data["type"] == "access_key": + project_credentials = project_config.credentials.__root__ + if project_credentials.type == "access_key": session = Session( - region_name=config.region_name, - aws_access_key_id=credentials_data["access_key"], - aws_secret_access_key=credentials_data["secret_key"], + 
region_name=project_config.region_name, + aws_access_key_id=project_credentials.access_key, + aws_secret_access_key=project_credentials.secret_key, ) if not self._valid_credentials(session=session): self._raise_invalid_credentials_error( @@ -74,27 +73,25 @@ def configure_project(self, config_data: Dict) -> AWSProjectValues: # TODO validate config values project_values.region_name = self._get_hub_regions(default_region=session.region_name) project_values.s3_bucket_name = self._get_hub_buckets( - session=session, region=config.region_name, default_bucket=config.bucket_name + session=session, + region=project_config.region_name, + default_bucket=project_config.s3_bucket_name, ) project_values.ec2_subnet_id = self._get_hub_subnet( - session=session, default_subnet=config.subnet_id + session=session, default_subnet=project_config.ec2_subnet_id ) return project_values - def create_config_auth_data_from_project_config( - self, project_config: AWSProjectConfigWithCreds - ) -> Tuple[Dict, Dict]: - project_config.s3_bucket_name = project_config.s3_bucket_name.replace("s3://", "") - config = AWSProjectConfig.parse_obj(project_config).dict() - auth = project_config.credentials.__root__.dict() - return config, auth - - def get_backend_config_from_hub_config_data( - self, project_name: str, config_data: Dict, auth_data: Dict - ) -> BackendConfig: - return AWSConfig.deserialize(config_data, auth_data) + def create_project(self, project_config: AWSProjectConfigWithCreds) -> Tuple[Dict, Dict]: + config_data = { + "region_name": project_config.region_name, + "s3_bucket_name": project_config.s3_bucket_name.replace("s3://", ""), + "ec2_subnet_id": project_config.ec2_subnet_id, + } + auth_data = project_config.credentials.__root__.dict() + return config_data, auth_data - def get_project_config_from_project( + def get_project_config( self, project: Project, include_creds: bool ) -> Union[AWSProjectConfig, AWSProjectConfigWithCreds]: json_config = json.loads(project.config) @@ -117,6 
+114,22 @@ def get_project_config_from_project( ec2_subnet_id=ec2_subnet_id, ) + def get_backend(self, project: Project) -> AwsBackend: + config_data = json.loads(project.config) + auth_data = json.loads(project.auth) + config = AWSConfig( + bucket_name=config_data.get("bucket") + or config_data.get("bucket_name") + or config_data.get("s3_bucket_name"), + region_name=config_data.get("region_name"), + profile_name=config_data.get("profile_name"), + subnet_id=config_data.get("subnet_id") + or config_data.get("ec2_subnet_id") + or config_data.get("subnet"), + credentials=auth_data, + ) + return AwsBackend(config) + def _valid_credentials(self, session: Session) -> bool: sts = session.client("sts") try: @@ -134,7 +147,7 @@ def _raise_invalid_credentials_error(self, fields: Optional[List[List[str]]] = N def _get_hub_regions(self, default_region: Optional[str]) -> ProjectElement: element = ProjectElement(selected=default_region) - for r in regions: + for r in REGIONS: element.values.append(ProjectElementValue(value=r[1], label=r[0])) return element diff --git a/cli/dstack/_internal/hub/services/backends/azure/configurator.py b/cli/dstack/_internal/hub/services/backends/azure/configurator.py index d4d9d4b82..11db7184d 100644 --- a/cli/dstack/_internal/hub/services/backends/azure/configurator.py +++ b/cli/dstack/_internal/hub/services/backends/azure/configurator.py @@ -52,6 +52,7 @@ from dstack._internal.hub.models import ( AzureProjectConfig, AzureProjectConfigWithCreds, + AzureProjectConfigWithCredsPartial, AzureProjectCreds, AzureProjectValues, ProjectElement, @@ -89,31 +90,33 @@ class AzureConfigurator(Configurator): def get_backend_class(self) -> type: return AzureBackend - def configure_project(self, config_data: Dict) -> AzureProjectValues: + def configure_project( + self, project_config: AzureProjectConfigWithCredsPartial + ) -> AzureProjectValues: project_values = AzureProjectValues() self.credential = DefaultAzureCredential() try: - project_values.tenant_id = 
self._get_tenant_id( - default_tenant_id=config_data.get("tenant_id") + project_values.tenant_id = self._get_tenant_id_element( + selected=project_config.tenant_id ) except ClientAuthenticationError: project_values.default_credentials = False else: project_values.default_credentials = True - credentials_data = config_data.get("credentials") - if credentials_data is None: + if project_config.credentials is None: return project_values - if credentials_data["type"] == "client": + project_credentials = project_config.credentials.__root__ + if project_credentials.type == "client": self.credential = ClientSecretCredential( - tenant_id=config_data.get("tenant_id"), - client_id=credentials_data["client_id"], - client_secret=credentials_data["client_secret"], + tenant_id=project_config.tenant_id, + client_id=project_credentials.client_id, + client_secret=project_credentials.client_secret, ) try: - project_values.tenant_id = self._get_tenant_id( - default_tenant_id=config_data.get("tenant_id") + project_values.tenant_id = self._get_tenant_id_element( + selected=project_config.tenant_id ) except ClientAuthenticationError: self._raise_invalid_credentials_error( @@ -129,24 +132,22 @@ def configure_project(self, config_data: Dict) -> AzureProjectValues: self.tenant_id = project_values.tenant_id.selected if self.tenant_id is None: return project_values - project_values.subscription_id = self._get_subscription_id( - default_subscription_id=config_data.get("subscription_id") + project_values.subscription_id = self._get_subscription_id_element( + selected=project_config.subscription_id ) self.subscription_id = project_values.subscription_id.selected if self.subscription_id is None: return project_values - project_values.location = self._get_location(default_location=config_data.get("location")) + project_values.location = self._get_location_element(selected=project_config.location) self.location = project_values.location.selected if self.location is None: return project_values - 
project_values.storage_account = self._get_storage_account( - default_storage_account=config_data.get("storage_account") + project_values.storage_account = self._get_storage_account_element( + selected=project_config.storage_account ) return project_values - def create_config_auth_data_from_project_config( - self, project_config: AzureProjectConfigWithCreds - ) -> Tuple[Dict, Dict]: + def create_project(self, project_config: AzureProjectConfigWithCreds) -> Tuple[Dict, Dict]: self.tenant_id = project_config.tenant_id self.subscription_id = project_config.subscription_id self.location = project_config.location @@ -167,7 +168,6 @@ def create_config_auth_data_from_project_config( self._create_logs_resources() self._grant_roles_or_error() config_data = { - "backend": "azure", "tenant_id": self.tenant_id, "subscription_id": self.subscription_id, "location": self.location, @@ -180,12 +180,7 @@ def create_config_auth_data_from_project_config( auth_data = project_config.credentials.__root__.dict() return config_data, auth_data - def get_backend_config_from_hub_config_data( - self, project_name: str, config_data: Dict, auth_data: Dict - ) -> AzureConfig: - return AzureConfig.deserialize({**config_data, "credentials": auth_data}) - - def get_project_config_from_project( + def get_project_config( self, project: Project, include_creds: bool ) -> Union[AzureProjectConfig, AzureProjectConfigWithCreds]: json_config = json.loads(project.config) @@ -209,6 +204,22 @@ def get_project_config_from_project( storage_account=storage_account, ) + def get_backend(self, project: Project) -> AzureBackend: + config_data = json.loads(project.config) + auth_data = json.loads(project.auth) + config = AzureConfig( + tenant_id=config_data["tenant_id"], + subscription_id=config_data["subscription_id"], + location=config_data["location"], + resource_group=config_data["resource_group"], + storage_account=config_data["storage_account"], + vault_url=config_data["vault_url"], + 
network=config_data["network"], + subnet=config_data["subnet"], + credentials=auth_data, + ) + return AzureBackend(config) + def _raise_invalid_credentials_error(self, fields: Optional[List[List[str]]] = None): raise BackendConfigError( "Invalid credentials", @@ -216,16 +227,16 @@ def _raise_invalid_credentials_error(self, fields: Optional[List[List[str]]] = N fields=fields, ) - def _get_tenant_id(self, default_tenant_id: Optional[str]) -> ProjectElement: + def _get_tenant_id_element(self, selected: Optional[str]) -> ProjectElement: subscription_client = SubscriptionClient(credential=self.credential) - element = ProjectElement(selected=default_tenant_id) + element = ProjectElement(selected=selected) tenant_ids = [] for tenant in subscription_client.tenants.list(): tenant_ids.append(tenant.tenant_id) element.values.append( ProjectElementValue(value=tenant.tenant_id, label=tenant.tenant_id) ) - if default_tenant_id is not None and default_tenant_id not in tenant_ids: + if selected is not None and selected not in tenant_ids: raise BackendConfigError( "Invalid tenant_id", code="invalid_tenant_id", fields=[["tenant_id"]] ) @@ -233,9 +244,9 @@ def _get_tenant_id(self, default_tenant_id: Optional[str]) -> ProjectElement: element.selected = tenant_ids[0] return element - def _get_subscription_id(self, default_subscription_id: Optional[str]) -> ProjectElement: + def _get_subscription_id_element(self, selected: Optional[str]) -> ProjectElement: subscription_client = SubscriptionClient(credential=self.credential) - element = ProjectElement(selected=default_subscription_id) + element = ProjectElement(selected=selected) subscription_ids = [] for subscription in subscription_client.subscriptions.list(): subscription_ids.append(subscription.subscription_id) @@ -245,7 +256,7 @@ def _get_subscription_id(self, default_subscription_id: Optional[str]) -> Projec label=f"{subscription.display_name} ({subscription.subscription_id})", ) ) - if default_subscription_id is not None and 
default_subscription_id not in subscription_ids: + if selected is not None and selected not in subscription_ids: raise BackendConfigError( "Invalid subscription_id", code="invalid_subscription_id", @@ -261,12 +272,12 @@ def _get_subscription_id(self, default_subscription_id: Optional[str]) -> Projec ) return element - def _get_location(self, default_location: Optional[str]) -> ProjectElement: - if default_location is not None and default_location not in LOCATION_VALUES: + def _get_location_element(self, selected: Optional[str]) -> ProjectElement: + if selected is not None and selected not in LOCATION_VALUES: raise BackendConfigError( "Invalid location", code="invalid_location", fields=[["location"]] ) - element = ProjectElement(selected=default_location) + element = ProjectElement(selected=selected) for l in LOCATIONS: element.values.append( ProjectElementValue( @@ -276,17 +287,17 @@ def _get_location(self, default_location: Optional[str]) -> ProjectElement: ) return element - def _get_storage_account(self, default_storage_account: Optional[str]) -> ProjectElement: + def _get_storage_account_element(self, selected: Optional[str]) -> ProjectElement: client = StorageManagementClient( credential=self.credential, subscription_id=self.subscription_id ) - element = ProjectElement(selected=default_storage_account) + element = ProjectElement(selected=selected) storage_accounts = [] for sa in client.storage_accounts.list(): if sa.provisioning_state == "Succeeded" and sa.location == self.location: storage_accounts.append(sa.name) element.values.append(ProjectElementValue(value=sa.name, label=sa.name)) - if default_storage_account is not None and default_storage_account not in storage_accounts: + if selected is not None and selected not in storage_accounts: raise BackendConfigError( "Invalid storage_account", code="invalid_storage_account", diff --git a/cli/dstack/_internal/hub/services/backends/base.py b/cli/dstack/_internal/hub/services/backends/base.py index 
78e54853d..f206b0825 100644 --- a/cli/dstack/_internal/hub/services/backends/base.py +++ b/cli/dstack/_internal/hub/services/backends/base.py @@ -1,9 +1,14 @@ from abc import ABC, abstractmethod from typing import Dict, List, Tuple, Union -from dstack._internal.backend.base.config import BackendConfig +from dstack._internal.backend.base import Backend from dstack._internal.hub.db.models import Project -from dstack._internal.hub.models import AnyProjectConfig, AnyProjectConfigWithCreds, ProjectValues +from dstack._internal.hub.models import ( + AnyProjectConfig, + AnyProjectConfigWithCreds, + AnyProjectConfigWithCredsPartial, + ProjectValues, +) class BackendConfigError(Exception): @@ -17,27 +22,19 @@ class Configurator(ABC): NAME = None @abstractmethod - def get_backend_class(self) -> type: + def configure_project(self, project_config: AnyProjectConfigWithCredsPartial) -> ProjectValues: pass @abstractmethod - def configure_project(self, config_data: Dict) -> ProjectValues: + def create_project(self, project_config: AnyProjectConfigWithCreds) -> Tuple[Dict, Dict]: pass @abstractmethod - def create_config_auth_data_from_project_config( - self, project_config: AnyProjectConfigWithCreds - ) -> Tuple[Dict, Dict]: - pass - - @abstractmethod - def get_project_config_from_project( + def get_project_config( self, project: Project, include_creds: bool ) -> Union[AnyProjectConfig, AnyProjectConfigWithCreds]: pass @abstractmethod - def get_backend_config_from_hub_config_data( - self, project_name: str, config_data: Dict, auth_data: Dict - ) -> BackendConfig: + def get_backend(self, project: Project) -> Backend: pass diff --git a/cli/dstack/_internal/hub/services/backends/gcp/configurator.py b/cli/dstack/_internal/hub/services/backends/gcp/configurator.py index e5b93995c..8597e7ee6 100644 --- a/cli/dstack/_internal/hub/services/backends/gcp/configurator.py +++ b/cli/dstack/_internal/hub/services/backends/gcp/configurator.py @@ -16,6 +16,7 @@ from dstack._internal.hub.models 
import ( GCPProjectConfig, GCPProjectConfigWithCreds, + GCPProjectConfigWithCredsPartial, GCPProjectCreds, GCPProjectValues, GCPVPCSubnetProjectElement, @@ -114,7 +115,9 @@ class GCPConfigurator(Configurator): def get_backend_class(self) -> type: return GCPBackend - def configure_project(self, config_data: Dict) -> GCPProjectValues: + def configure_project( + self, project_config: GCPProjectConfigWithCredsPartial + ) -> GCPProjectValues: project_values = GCPProjectValues() try: self.credentials, self.project_id = google.auth.default() @@ -123,13 +126,13 @@ def configure_project(self, config_data: Dict) -> GCPProjectValues: else: project_values.default_credentials = True - credentials_data = config_data.get("credentials") - if credentials_data is None: + if project_config.credentials is None: return project_values - if credentials_data["type"] == "service_account": + project_credentials = project_config.credentials.__root__ + if project_credentials.type == "service_account": try: - self._auth(credentials_data) + self._auth(project_credentials.dict()) storage_client = storage.Client(credentials=self.credentials) storage_client.list_buckets(max_results=1) except Exception: @@ -137,73 +140,63 @@ def configure_project(self, config_data: Dict) -> GCPProjectValues: elif not project_values.default_credentials: self._raise_invalid_credentials_error(fields=[["credentials"]]) - project_values.area = self._get_hub_geographic_area(config_data.get("area")) + project_values.area = self._get_hub_geographic_area_element(project_config.area) location = self._get_location(project_values.area.selected) - project_values.region, regions = self._get_hub_region( + project_values.region, regions = self._get_hub_region_element( location=location, - default_region=config_data.get("region"), + selected=project_config.region, ) - project_values.zone = self._get_hub_zone( + project_values.zone = self._get_hub_zone_element( location=location, region=regions.get(project_values.region.selected), - 
default_zone=config_data.get("zone"), + selected=project_config.zone, ) - project_values.bucket_name = self._get_hub_buckets( + project_values.bucket_name = self._get_hub_buckets_element( region=project_values.region.selected, - default_bucket=config_data.get("bucket_name"), + selected=project_config.bucket_name, ) - project_values.vpc_subnet = self._get_hub_vpc_subnet( + project_values.vpc_subnet = self._get_hub_vpc_subnet_element( region=project_values.region.selected, - default_vpc=config_data.get("vpc"), - default_subnet=config_data.get("subnet"), + selected_vpc=project_config.vpc, + selected_subnet=project_config.subnet, ) return project_values - def create_config_auth_data_from_project_config( - self, project_config: GCPProjectConfigWithCreds - ) -> Tuple[Dict, Dict]: - config_data = GCPProjectConfig.parse_obj(project_config).dict() + def create_project(self, project_config: GCPProjectConfigWithCreds) -> Tuple[Dict, Dict]: auth_data = project_config.credentials.__root__.dict() + self._auth(auth_data) if project_config.credentials.__root__.type == "default": - self._auth(auth_data) service_account_email = self._get_or_create_service_account( f"{project_config.bucket_name}-sa" ) self._grant_roles_to_service_account(service_account_email) self._check_if_can_create_service_account_key(service_account_email) auth_data["service_account_email"] = service_account_email - return config_data, auth_data - - def get_backend_config_from_hub_config_data( - self, project_name: str, config_data: Dict, auth_data: Dict - ) -> GCPConfig: - self._auth(auth_data) - data = { - "backend": "gcp", - "credentials": auth_data, + config_data = { "project": self.project_id, - "region": config_data["region"], - "zone": config_data["zone"], - "bucket": config_data["bucket_name"], - "vpc": config_data["vpc"], - "subnet": config_data["subnet"], + "area": project_config.area, + "region": project_config.region, + "zone": project_config.zone, + "bucket_name": project_config.bucket_name, + 
"vpc": project_config.bucket_name, + "subnet": project_config.subnet, } - return GCPConfig.deserialize(data) + return config_data, auth_data - def get_project_config_from_project( + def get_project_config( self, project: Project, include_creds: bool ) -> Union[GCPProjectConfig, GCPProjectConfigWithCreds]: - json_config = json.loads(project.config) - area = json_config["area"] - region = json_config["region"] - zone = json_config["zone"] - bucket_name = json_config["bucket_name"] - vpc = json_config["vpc"] - subnet = json_config["subnet"] + config_data = json.loads(project.config) + area = config_data["area"] + region = config_data["region"] + zone = config_data["zone"] + bucket_name = config_data["bucket_name"] + vpc = config_data["vpc"] + subnet = config_data["subnet"] if include_creds: - json_auth = json.loads(project.auth) + auth_data = json.loads(project.auth) return GCPProjectConfigWithCreds( - credentials=GCPProjectCreds.parse_obj(json_auth), + credentials=GCPProjectCreds.parse_obj(auth_data), area=area, region=region, zone=zone, @@ -220,19 +213,38 @@ def get_project_config_from_project( subnet=subnet, ) - def _get_hub_geographic_area(self, default_area: Optional[str]) -> ProjectElement: + def get_backend(self, project: Project) -> GCPBackend: + config_data = json.loads(project.config) + auth_data = json.loads(project.auth) + project_id = config_data.get("project") + # Legacy config_data does not store project + if project_id is None: + self._auth(auth_data) + project_id = self.project_id + config = GCPConfig( + project_id=project_id, + region=config_data["region"], + zone=config_data["zone"], + bucket_name=config_data["bucket_name"], + vpc=config_data["vpc"], + subnet=config_data["subnet"], + credentials=auth_data, + ) + return GCPBackend(config) + + def _get_hub_geographic_area_element(self, selected: Optional[str]) -> ProjectElement: area_names = sorted([l["name"] for l in GCP_LOCATIONS]) - if default_area is None: - default_area = DEFAULT_GEOGRAPHIC_AREA 
- if default_area not in area_names: - raise BackendConfigError(f"Invalid GCP area {default_area}") - element = ProjectElement(selected=default_area) + if selected is None: + selected = DEFAULT_GEOGRAPHIC_AREA + if selected not in area_names: + raise BackendConfigError(f"Invalid GCP area {selected}") + element = ProjectElement(selected=selected) for area_name in area_names: element.values.append(ProjectElementValue(value=area_name, label=area_name)) return element - def _get_hub_region( - self, location: Dict, default_region: Optional[str] + def _get_hub_region_element( + self, location: Dict, selected: Optional[str] ) -> Tuple[ProjectElement, Dict]: regions_client = compute_v1.RegionsClient(credentials=self.credentials) regions = regions_client.list(project=self.project_id) @@ -240,13 +252,11 @@ def _get_hub_region( [r.name for r in regions if r.name in location["regions"]], key=lambda name: (name != location["default_region"], name), ) - if default_region is None: - default_region = region_names[0] - if default_region not in region_names: - raise BackendConfigError( - f"Invalid GCP region {default_region} in area {location['name']}" - ) - element = ProjectElement(selected=default_region) + if selected is None: + selected = region_names[0] + if selected not in region_names: + raise BackendConfigError(f"Invalid GCP region {selected} in area {location['name']}") + element = ProjectElement(selected=selected) for region_name in region_names: element.values.append(ProjectElementValue(value=region_name, label=region_name)) return element, {r.name: r for r in regions} @@ -257,49 +267,49 @@ def _get_location(self, area: str) -> Optional[Dict]: return location return None - def _get_hub_zone( - self, location: Dict, region: compute_v1.Region, default_zone: Optional[str] + def _get_hub_zone_element( + self, location: Dict, region: compute_v1.Region, selected: Optional[str] ) -> ProjectElement: zone_names = sorted( [gcp_utils.get_resource_name(z) for z in region.zones], 
key=lambda name: (name != location["default_zone"], name), ) - if default_zone is None: - default_zone = zone_names[0] - if default_zone not in zone_names: - raise BackendConfigError(f"Invalid GCP zone {default_zone} in region {region.name}") - element = ProjectElement(selected=default_zone) + if selected is None: + selected = zone_names[0] + if selected not in zone_names: + raise BackendConfigError(f"Invalid GCP zone {selected} in region {region.name}") + element = ProjectElement(selected=selected) for zone_name in zone_names: element.values.append(ProjectElementValue(value=zone_name, label=zone_name)) return element - def _get_hub_buckets( - self, region: str, default_bucket: Optional[str] = None + def _get_hub_buckets_element( + self, region: str, selected: Optional[str] = None ) -> ProjectElement: storage_client = storage.Client(credentials=self.credentials) buckets = storage_client.list_buckets() bucket_names = [bucket.name for bucket in buckets if bucket.location.lower() == region] - if default_bucket is not None and default_bucket not in bucket_names: + if selected is not None and selected not in bucket_names: raise BackendConfigError( - f"Invalid bucket {default_bucket} for region {region}", + f"Invalid bucket {selected} for region {region}", code="invalid_bucket", fields=[["bucket_name"]], ) - element = ProjectElement(selected=default_bucket) + element = ProjectElement(selected=selected) for bucket_name in bucket_names: element.values.append(ProjectElementValue(value=bucket_name, label=bucket_name)) return element - def _get_hub_vpc_subnet( + def _get_hub_vpc_subnet_element( self, region: str, - default_vpc: Optional[str], - default_subnet: Optional[str], + selected_vpc: Optional[str], + selected_subnet: Optional[str], ) -> GCPVPCSubnetProjectElement: - if default_vpc is None: - default_vpc = "default" - if default_subnet is None: - default_subnet = "default" + if selected_vpc is None: + selected_vpc = "default" + if selected_subnet is None: + 
selected_subnet = "default" no_preference_vpc_subnet = ("default", "default") networks_client = compute_v1.NetworksClient(credentials=self.credentials) networks = networks_client.list(project=self.project_id) @@ -310,10 +320,10 @@ def _get_hub_vpc_subnet( if subnet_region != region: continue vpc_subnet_list.append((network.name, gcp_utils.get_subnet_name(subnet))) - if (default_vpc, default_subnet) not in vpc_subnet_list: - raise BackendConfigError(f"Invalid VPC subnet {default_vpc, default_subnet}") - if (default_vpc, default_subnet) != no_preference_vpc_subnet: - selected = f"{default_subnet} ({default_vpc})" + if (selected_vpc, selected_subnet) not in vpc_subnet_list: + raise BackendConfigError(f"Invalid VPC subnet {selected_vpc, selected_subnet}") + if (selected_vpc, selected_subnet) != no_preference_vpc_subnet: + selected = f"{selected_subnet} ({selected_vpc})" else: selected = f"No preference (default)" vpc_subnet_list = sorted(vpc_subnet_list, key=lambda t: t != no_preference_vpc_subnet) diff --git a/cli/dstack/_internal/hub/services/backends/lambdalabs/configurator.py b/cli/dstack/_internal/hub/services/backends/lambdalabs/configurator.py index 5648fed30..828b66f09 100644 --- a/cli/dstack/_internal/hub/services/backends/lambdalabs/configurator.py +++ b/cli/dstack/_internal/hub/services/backends/lambdalabs/configurator.py @@ -27,38 +27,37 @@ ProjectElementValue, ProjectValues, ) -from dstack._internal.hub.services.backends.base import BackendConfigError +from dstack._internal.hub.services.backends.base import BackendConfigError, Configurator -class LambdaConfigurator: +class LambdaConfigurator(Configurator): NAME = "lambda" - def get_backend_class(self) -> type: - return LambdaBackend - - def configure_project(self, config: LambdaProjectConfigWithCredsPartial) -> ProjectValues: + def configure_project( + self, project_config: LambdaProjectConfigWithCredsPartial + ) -> ProjectValues: selected_storage_backend = None - if config.storage_backend is not None: - 
selected_storage_backend = config.storage_backend.type + if project_config.storage_backend is not None: + selected_storage_backend = project_config.storage_backend.type storage_backend_type = self._get_storage_backend_type_element( selected=selected_storage_backend ) project_values = LambdaProjectValues(storage_backend_type=storage_backend_type) - if config.api_key is None: + if project_config.api_key is None: return project_values - self._validate_lambda_api_key(api_key=config.api_key) - if config.storage_backend is None or config.storage_backend.credentials is None: + self._validate_lambda_api_key(api_key=project_config.api_key) + if ( + project_config.storage_backend is None + or project_config.storage_backend.credentials is None + ): return project_values project_values.storage_backend_values = self._get_aws_storage_backend_values( - config.storage_backend + project_config.storage_backend ) return project_values - def create_config_auth_data_from_project_config( - self, project_config: LambdaProjectConfigWithCreds - ) -> Tuple[Dict, Dict]: + def create_project(self, project_config: LambdaProjectConfigWithCreds) -> Tuple[Dict, Dict]: config_data = { - "backend": "lambda", "storage_backend": self._get_aws_storage_backend_config_data( project_config.storage_backend ), @@ -69,7 +68,7 @@ def create_config_auth_data_from_project_config( } return config_data, auth_data - def get_project_config_from_project( + def get_project_config( self, project: Project, include_creds: bool ) -> Union[LambdaProjectConfig, LambdaProjectConfigWithCreds]: config_data = json.loads(project.config) @@ -90,10 +89,10 @@ def get_project_config_from_project( ) ) - def get_backend_config_from_hub_config_data( - self, project_name: str, config_data: Dict, auth_data: Dict - ) -> LambdaConfig: - return LambdaConfig( + def get_backend(self, project: Project) -> LambdaBackend: + config_data = json.loads(project.config) + auth_data = json.loads(project.auth) + config = LambdaConfig( 
api_key=auth_data["api_key"], storage_config=AWSStorageConfig( bucket=config_data["storage_backend"]["bucket"], @@ -104,6 +103,7 @@ def get_backend_config_from_hub_config_data( ), ), ) + return LambdaBackend(config) def _get_storage_backend_type_element(self, selected: Optional[str]) -> ProjectElement: element = ProjectElement( @@ -175,7 +175,7 @@ def _get_aws_storage_backend_config_data( aws_secret_access_key=config.credentials.secret_key, ) return { - "backend": "aws", + "type": "aws", "bucket": config.bucket_name, "region": self._get_aws_bucket_region(session, config.bucket_name), } diff --git a/cli/dstack/_internal/hub/services/backends/local/configurator.py b/cli/dstack/_internal/hub/services/backends/local/configurator.py index d41c4e242..a8cdda398 100644 --- a/cli/dstack/_internal/hub/services/backends/local/configurator.py +++ b/cli/dstack/_internal/hub/services/backends/local/configurator.py @@ -10,26 +10,16 @@ class LocalConfigurator(Configurator): NAME = "local" - def get_backend_class(self) -> type: - return LocalBackend - - def configure_project(self, config_data: Dict) -> ProjectValues: + def configure_project(self, project_config: LocalProjectConfig) -> ProjectValues: return None - def get_backend_config_from_hub_config_data( - self, project_name: str, config_data: Dict, auth_data: Dict - ) -> LocalConfig: - return LocalConfig(namespace=project_name) - - def create_config_auth_data_from_project_config( - self, project_config: LocalProjectConfig - ) -> Tuple[Dict, Dict]: + def create_project(self, project_config: LocalProjectConfig) -> Tuple[Dict, Dict]: return {}, {} - def get_project_config_from_project( - self, project: Project, include_creds: bool - ) -> LocalProjectConfig: - config = self.get_backend_config_from_hub_config_data( - project.name, project.config, project.auth - ) + def get_project_config(self, project: Project, include_creds: bool) -> LocalProjectConfig: + config = LocalConfig(namespace=project.name) return 
LocalProjectConfig(path=str(config.backend_dir)) + + def get_backend(self, project: Project) -> LocalBackend: + config = LocalConfig(namespace=project.name) + return LocalBackend(config) From a67b62170e44506b419a9bc699c1d85c633de8bb Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 28 Jun 2023 16:21:24 +0500 Subject: [PATCH 12/16] Fix runner version pin --- .../_internal/backend/lambdalabs/__init__.py | 7 +- .../_internal/backend/lambdalabs/compute.py | 90 +++++++++++++++---- 2 files changed, 73 insertions(+), 24 deletions(-) diff --git a/cli/dstack/_internal/backend/lambdalabs/__init__.py b/cli/dstack/_internal/backend/lambdalabs/__init__.py index ddeffb0d3..76afee236 100644 --- a/cli/dstack/_internal/backend/lambdalabs/__init__.py +++ b/cli/dstack/_internal/backend/lambdalabs/__init__.py @@ -56,12 +56,7 @@ def __init__( @classmethod def load(cls) -> Optional["LambdaBackend"]: - config = AWSConfig.load() - if config is None: - return None - return cls( - backend_config=config, - ) + return None def _s3_client(self) -> BaseClient: return self._get_client("s3") diff --git a/cli/dstack/_internal/backend/lambdalabs/compute.py b/cli/dstack/_internal/backend/lambdalabs/compute.py index 25d440d9f..e768e773c 100644 --- a/cli/dstack/_internal/backend/lambdalabs/compute.py +++ b/cli/dstack/_internal/backend/lambdalabs/compute.py @@ -10,7 +10,7 @@ import yaml from dstack import version -from dstack._internal.backend.base.compute import WS_PORT, NoCapacityError, choose_instance_type +from dstack._internal.backend.base.compute import WS_PORT, choose_instance_type from dstack._internal.backend.base.config import BACKEND_CONFIG_FILENAME, RUNNER_CONFIG_FILENAME from dstack._internal.backend.base.runners import serialize_runner_yaml from dstack._internal.backend.lambdalabs.api_client import LambdaAPIClient @@ -25,6 +25,75 @@ _WAIT_FOR_INSTANCE_INTERVAL = 10 +_INSTANCE_TYPE_TO_GPU_DATA_MAP = { + "gpu_1x_h100_pcie": { + "name": "H100", + "count": 1, + "memory_mib": 80 * 
1024, + }, + "gpu_8x_a100_80gb_sxm4": { + "name": "A100", + "count": 8, + "memory_mib": 80 * 1024, + }, + "gpu_1x_a10": { + "name": "A10", + "count": 1, + "memory_mib": 24 * 1024, + }, + "gpu_1x_rtx6000": { + "name": "RTX6000", + "count": 1, + "memory_mib": 24 * 1024, + }, + "gpu_1x_a100": { + "name": "A100", + "count": 1, + "memory_mib": 40 * 1024, + }, + "gpu_1x_a100_sxm4": { + "name": "A100", + "count": 1, + "memory_mib": 40 * 1024, + }, + "gpu_2x_a100": { + "name": "A100", + "count": 2, + "memory_mib": 40 * 1024, + }, + "gpu_4x_a100": { + "name": "A100", + "count": 4, + "memory_mib": 40 * 1024, + }, + "gpu_8x_a100": { + "name": "A100", + "count": 8, + "memory_mib": 40 * 1024, + }, + "gpu_1x_a6000": { + "name": "A6000", + "count": 1, + "memory_mib": 48 * 1024, + }, + "gpu_2x_a6000": { + "name": "A6000", + "count": 2, + "memory_mib": 48 * 1024, + }, + "gpu_4x_a6000": { + "name": "A6000", + "count": 4, + "memory_mib": 48 * 1024, + }, + "gpu_8x_v100": { + "name": "V100", + "count": 8, + "memory_mib": 16 * 1024, + }, +} + + class LambdaCompute: def __init__(self, lambda_config: LambdaConfig): self.lambda_config = lambda_config @@ -106,20 +175,6 @@ def _instance_type_data_to_instance_type(instance_type_data: Dict) -> Optional[I ) -_INSTANCE_TYPE_TO_GPU_DATA_MAP = { - "gpu_1x_a10": { - "name": "A10", - "count": 1, - "memory_mib": 24 * 1024, - }, - "gpu_1x_rtx6000": { - "name": "RTX6000", - "count": 1, - "memory_mib": 24 * 1024, - }, -} - - def _get_instance_type_gpus(instance_type_name: str) -> Optional[List[Gpu]]: gpu_data = _INSTANCE_TYPE_TO_GPU_DATA_MAP.get(instance_type_name) if gpu_data is None: @@ -165,7 +220,6 @@ def _run_instance( file_system_names=[], ) instance_id = instances_ids[0] - # instance_id = api_client.list_instances()[0]["id"] instance_info = _wait_for_instance(api_client, instance_id) thread = Thread( target=_start_runner, @@ -236,8 +290,8 @@ def _setup_instance(hostname: str, user_ssh_key: str): def _get_setup_command() -> str: if not 
version.__is_release__: # TODO: replace with latest after merge - return "ENVIRONMENT=stage RUNNER=1351 .dstack/setup_lambda.sh" - return f"ENVIRONMENT=prod RUNNER={version.__version__} .dstack/setup_lambda.sh" + return "ENVIRONMENT=stage RUNNER_VERSION=1351 .dstack/setup_lambda.sh" + return f"ENVIRONMENT=prod RUNNER_VERSION={version.__version__} .dstack/setup_lambda.sh" def _launch_runner(hostname: str, launch_script: str): From ae46bb729ab444e6f77fc8c919f76a97f88409f1 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 28 Jun 2023 16:21:56 +0500 Subject: [PATCH 13/16] Refactor dstack init --- .../_internal/cli/commands/init/__init__.py | 31 ++----------------- 1 file changed, 2 insertions(+), 29 deletions(-) diff --git a/cli/dstack/_internal/cli/commands/init/__init__.py b/cli/dstack/_internal/cli/commands/init/__init__.py index 4b15c8c4c..2453ae3c8 100644 --- a/cli/dstack/_internal/cli/commands/init/__init__.py +++ b/cli/dstack/_internal/cli/commands/init/__init__.py @@ -1,12 +1,8 @@ -import os from argparse import Namespace from pathlib import Path from typing import Optional import giturlparse -from cryptography.hazmat.backends import default_backend as crypto_default_backend -from cryptography.hazmat.primitives import serialization as crypto_serialization -from cryptography.hazmat.primitives.asymmetric import rsa from git.exc import InvalidGitRepositoryError from dstack._internal.api.repos import InvalidRepoCredentialsError, get_local_repo_credentials @@ -16,6 +12,7 @@ from dstack._internal.cli.errors import CLIError from dstack._internal.core.repo import LocalRepo, RemoteRepo from dstack._internal.core.userconfig import RepoUserConfig +from dstack._internal.utils.crypto import generage_rsa_key_pair class InitCommand(BasicCommand): @@ -106,29 +103,5 @@ def get_ssh_keypair( if dstack_key_path is None: return None if not dstack_key_path.exists(): - key = rsa.generate_private_key( - backend=crypto_default_backend(), public_exponent=65537, key_size=2048 - 
) - - def key_opener(path, flags): - return os.open(path, flags, 0o600) - - with open(dstack_key_path, "wb", opener=key_opener) as f: - f.write( - key.private_bytes( - crypto_serialization.Encoding.PEM, - crypto_serialization.PrivateFormat.PKCS8, - crypto_serialization.NoEncryption(), - ) - ) - with open( - dstack_key_path.with_suffix(dstack_key_path.suffix + ".pub"), "wb", opener=key_opener - ) as f: - f.write( - key.public_key().public_bytes( - crypto_serialization.Encoding.OpenSSH, - crypto_serialization.PublicFormat.OpenSSH, - ) - ) - f.write(b" dstack\n") + generage_rsa_key_pair(private_key_path=dstack_key_path) return str(dstack_key_path) From 02e4acf945beba16167be0cd939d557c1fa81147 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 29 Jun 2023 09:22:37 +0500 Subject: [PATCH 14/16] Support dstack[lambda] install --- setup.py | 135 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 73 insertions(+), 62 deletions(-) diff --git a/setup.py b/setup.py index 1b636e905..69d4b4f02 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,74 @@ def get_long_description(): ) +BASE_DEPS = [ + "pyyaml", + "requests", + "gitpython", + "tqdm", + "jsonschema", + "python-dateutil", + "paramiko", + "git-url-parse", + "rich", + "rich-argparse", + "fastapi", + "starlette>=0.26.0", + "uvicorn", + "pydantic", + "sqlalchemy[asyncio]>=2.0.0", + "websocket-client", + "cursor", + "simple-term-menu", + "py-cpuinfo", + "pygtail", + "packaging", + "aiosqlite", + "apscheduler", + "alembic>=1.10.2", + "typing-extensions>=4.0.0", + "file-read-backwards>=3.0.0", + "psutil>=5.0.0", + "cryptography", + "grpcio>=1.50", # indirect + "filelock", + "watchfiles", +] + +AWS_DEPS = [ + "boto3", + "botocore", +] + +AZURE_DEPS = [ + "azure-identity>=1.12.0", + "azure-keyvault-secrets>=4.6.0", + "azure-storage-blob>=12.15.0", + "azure-monitor-query>=1.2.0", + "azure-mgmt-subscription>=3.1.1", + "azure-mgmt-compute>=29.1.0", + "azure-mgmt-network==23.0.0b2", + 
"azure-mgmt-resource>=22.0.0", + "azure-mgmt-authorization>=3.0.0", + "azure-mgmt-storage>=21.0.0", + "azure-mgmt-keyvault>=10.1.0", + "azure-mgmt-loganalytics==13.0.0b6", + "azure-mgmt-msi", + "azure-mgmt-monitor", + "azure-graphrbac", +] + +GCP_DEPS = [ + "google-auth>=2.3.0", # indirect + "google-cloud-storage>=2.0.0", + "google-cloud-compute>=1.5.0", + "google-cloud-secret-manager>=2.0.0", + "google-cloud-logging>=2.0.0", + "google-api-python-client>=2.80.0", +] + +LAMBDA_DEPS = AWS_DEPS + setup( name="dstack", version=get_version(), @@ -56,69 +124,12 @@ def get_long_description(): long_description=get_long_description(), long_description_content_type="text/markdown", python_requires=">=3.7", - install_requires=[ - "pyyaml", - "requests", - "gitpython", - "tqdm", - "jsonschema", - "python-dateutil", - "paramiko", - "git-url-parse", - "rich", - "rich-argparse", - "fastapi", - "starlette>=0.26.0", - "uvicorn", - "pydantic", - "sqlalchemy[asyncio]>=2.0.0", - "websocket-client", - "cursor", - "simple-term-menu", - "py-cpuinfo", - "pygtail", - "packaging", - "aiosqlite", - "apscheduler", - "alembic>=1.10.2", - "typing-extensions>=4.0.0", - "file-read-backwards>=3.0.0", - "psutil>=5.0.0", - "cryptography", - "grpcio>=1.50", # indirect - "filelock", - "watchfiles", - ], + install_requires=BASE_DEPS, extras_require={ - "aws": [ - "boto3", - "botocore", - ], - "azure": [ - "azure-identity>=1.12.0", - "azure-keyvault-secrets>=4.6.0", - "azure-storage-blob>=12.15.0", - "azure-monitor-query>=1.2.0", - "azure-mgmt-subscription>=3.1.1", - "azure-mgmt-compute>=29.1.0", - "azure-mgmt-network==23.0.0b2", - "azure-mgmt-resource>=22.0.0", - "azure-mgmt-authorization>=3.0.0", - "azure-mgmt-storage>=21.0.0", - "azure-mgmt-keyvault>=10.1.0", - "azure-mgmt-loganalytics==13.0.0b6", - "azure-mgmt-msi", - "azure-mgmt-monitor", - "azure-graphrbac", - ], - "gcp": [ - "google-auth>=2.3.0", # indirect - "google-cloud-storage>=2.0.0", - "google-cloud-compute>=1.5.0", - 
"google-cloud-secret-manager>=2.0.0", - "google-cloud-logging>=2.0.0", - "google-api-python-client>=2.80.0", - ], + "aws": AWS_DEPS, + "azure": AZURE_DEPS, + "gcp": GCP_DEPS, + "lambda": LAMBDA_DEPS, }, classifiers=[ "Development Status :: 2 - Pre-Alpha", From db7b1143947e03fc2a7d5a6608d0ed9b083fd742 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 29 Jun 2023 09:58:53 +0500 Subject: [PATCH 15/16] Support lambda regions config --- .../_internal/backend/lambdalabs/compute.py | 19 +++++--- .../_internal/backend/lambdalabs/config.py | 3 ++ cli/dstack/_internal/hub/models/__init__.py | 8 ++++ .../backends/lambdalabs/configurator.py | 44 +++++++++++++++++-- 4 files changed, 64 insertions(+), 10 deletions(-) diff --git a/cli/dstack/_internal/backend/lambdalabs/compute.py b/cli/dstack/_internal/backend/lambdalabs/compute.py index e768e773c..e77e494c7 100644 --- a/cli/dstack/_internal/backend/lambdalabs/compute.py +++ b/cli/dstack/_internal/backend/lambdalabs/compute.py @@ -109,7 +109,7 @@ def get_request_head(self, job: Job, request_id: Optional[str]) -> RequestHead: ) def get_instance_type(self, job: Job) -> Optional[InstanceType]: - instance_types = _list_instance_types(self.api_client) + instance_types = _list_instance_types(self.api_client, self.lambda_config.regions) return choose_instance_type( instance_types=instance_types, requirements=job.requirements, @@ -118,7 +118,7 @@ def get_instance_type(self, job: Job) -> Optional[InstanceType]: def run_instance(self, job: Job, instance_type: InstanceType) -> str: return _run_instance( api_client=self.api_client, - region_name=instance_type.available_regions[0], + region_name=_get_instance_region(instance_type, self.lambda_config.regions), instance_type_name=instance_type.instance_name, user_ssh_key=job.ssh_key_pub, hub_ssh_key=get_hub_ssh_public_key(), @@ -133,12 +133,14 @@ def cancel_spot_request(self, request_id: str): pass -def _list_instance_types(api_client: LambdaAPIClient) -> List[InstanceType]: +def 
_list_instance_types(api_client: LambdaAPIClient, regions: List[str]) -> List[InstanceType]: instance_types_data = api_client.list_instance_types() instance_types = [] for instance_type_data in instance_types_data.values(): instance_type = _instance_type_data_to_instance_type(instance_type_data) - if instance_type is not None: + if instance_type is None: + continue + if _get_instance_region(instance_type, regions) is not None: instance_types.append(instance_type) return instance_types @@ -156,8 +158,6 @@ def _instance_type_data_to_instance_type(instance_type_data: Dict) -> Optional[I instance_type = instance_type_data["instance_type"] regions_data = instance_type_data["regions_with_capacity_available"] regions = [r["name"] for r in regions_data] - if len(regions) == 0: - return None instance_type_specs = instance_type["specs"] gpus = _get_instance_type_gpus(instance_type["name"]) if gpus is None: @@ -185,6 +185,13 @@ def _get_instance_type_gpus(instance_type_name: str) -> Optional[List[Gpu]]: ] +def _get_instance_region(instance_type: InstanceType, regions: List[str]) -> Optional[str]: + for r in instance_type.available_regions: + if r in regions: + return r + return None + + def _get_instance_name(job: Job) -> str: return f"dstack-{job.run_name}" diff --git a/cli/dstack/_internal/backend/lambdalabs/config.py b/cli/dstack/_internal/backend/lambdalabs/config.py index 55770aae0..b48b29737 100644 --- a/cli/dstack/_internal/backend/lambdalabs/config.py +++ b/cli/dstack/_internal/backend/lambdalabs/config.py @@ -1,3 +1,5 @@ +from typing import List + from pydantic import BaseModel from typing_extensions import Literal @@ -16,5 +18,6 @@ class AWSStorageConfig(BaseModel): class LambdaConfig(BaseModel): backend: Literal["lambda"] = "lambda" + regions: List[str] api_key: str storage_config: AWSStorageConfig diff --git a/cli/dstack/_internal/hub/models/__init__.py b/cli/dstack/_internal/hub/models/__init__.py index 5e4a402e7..15dd760e0 100644 --- 
a/cli/dstack/_internal/hub/models/__init__.py +++ b/cli/dstack/_internal/hub/models/__init__.py @@ -182,11 +182,13 @@ class AWSStorageProjectConfigWithCreds(AWSStorageProjectConfig): class LambdaProjectConfigWithCredsPartial(BaseModel): type: Literal["lambda"] = "lambda" api_key: Optional[str] + regions: Optional[List[str]] storage_backend: Optional[AWSStorageProjectConfigWithCredsPartial] class LambdaProjectConfig(BaseModel): type: Literal["lambda"] = "lambda" + regions: List[str] storage_backend: AWSStorageProjectConfig @@ -352,6 +354,11 @@ class ProjectElement(BaseModel): values: List[ProjectElementValue] = [] +class ProjectMultiElement(BaseModel): + selected: List[str] + values: List[ProjectElementValue] = [] + + class AWSBucketProjectElementValue(BaseModel): name: str created: str @@ -409,6 +416,7 @@ class AWSStorageBackendValues(BaseModel): class LambdaProjectValues(BaseModel): type: Literal["lambda"] = "lambda" storage_backend_type: ProjectElement + regions: Optional[ProjectMultiElement] storage_backend_values: Optional[AWSStorageBackendValues] diff --git a/cli/dstack/_internal/hub/services/backends/lambdalabs/configurator.py b/cli/dstack/_internal/hub/services/backends/lambdalabs/configurator.py index 828b66f09..95f3c34bd 100644 --- a/cli/dstack/_internal/hub/services/backends/lambdalabs/configurator.py +++ b/cli/dstack/_internal/hub/services/backends/lambdalabs/configurator.py @@ -1,5 +1,5 @@ import json -from typing import Dict, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import botocore from boto3.session import Session @@ -25,17 +25,33 @@ LambdaProjectValues, ProjectElement, ProjectElementValue, - ProjectValues, + ProjectMultiElement, ) from dstack._internal.hub.services.backends.base import BackendConfigError, Configurator +REGIONS = [ + "us-south-1", + "us-west-2", + "us-west-1", + "us-midwest-1", + "us-west-3", + "us-east-1", + "australia-southeast-1", + "europe-central-1", + "asia-south-1", + "me-west-1", + 
"europe-south-1", + "asia-northeast-1", + "asia-northeast-1", +] + class LambdaConfigurator(Configurator): NAME = "lambda" def configure_project( self, project_config: LambdaProjectConfigWithCredsPartial - ) -> ProjectValues: + ) -> LambdaProjectValues: selected_storage_backend = None if project_config.storage_backend is not None: selected_storage_backend = project_config.storage_backend.type @@ -46,6 +62,7 @@ def configure_project( if project_config.api_key is None: return project_values self._validate_lambda_api_key(api_key=project_config.api_key) + project_values.regions = self._get_regions_element(selected=project_config.regions) if ( project_config.storage_backend is None or project_config.storage_backend.credentials is None @@ -58,6 +75,7 @@ def configure_project( def create_project(self, project_config: LambdaProjectConfigWithCreds) -> Tuple[Dict, Dict]: config_data = { + "regions": project_config.regions, "storage_backend": self._get_aws_storage_backend_config_data( project_config.storage_backend ), @@ -75,6 +93,7 @@ def get_project_config( if include_creds: auth_data = json.loads(project.auth) return LambdaProjectConfigWithCreds( + regions=config_data["regions"], api_key=auth_data["api_key"], storage_backend=AWSStorageProjectConfigWithCreds( bucket_name=config_data["storage_backend"]["bucket"], @@ -84,15 +103,17 @@ def get_project_config( ), ) return LambdaProjectConfig( + regions=config_data["regions"], storage_backend=AWSStorageProjectConfig( bucket_name=config_data["storage_backend"]["bucket"] - ) + ), ) def get_backend(self, project: Project) -> LambdaBackend: config_data = json.loads(project.config) auth_data = json.loads(project.auth) config = LambdaConfig( + regions=config_data["regions"], api_key=auth_data["api_key"], storage_config=AWSStorageConfig( bucket=config_data["storage_backend"]["bucket"], @@ -124,6 +145,21 @@ def _validate_lambda_api_key(self, api_key: str): ) raise e + def _get_regions_element(self, selected: Optional[List[str]]) -> 
ProjectMultiElement: + if selected is not None: + for r in selected: + if r not in REGIONS: + raise BackendConfigError( + "Invalid regions", + code="invalid_regions", + fields=[["regions"]], + ) + element = ProjectMultiElement( + selected=selected or REGIONS, + values=[ProjectElementValue(value=r, label=r) for r in REGIONS], + ) + return element + def _get_aws_storage_backend_values( self, config: AWSStorageProjectConfigWithCredsPartial ) -> AWSStorageBackendValues: From a350d17be2976ab4e83a1b839bc3dcda0300fc2a Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 29 Jun 2023 10:46:57 +0500 Subject: [PATCH 16/16] Support lambda in PythonAPI --- cli/dstack/_internal/api/workflow_api.py | 7 +++++++ .../_internal/backend/lambdalabs/__init__.py | 6 ++++-- .../_internal/backend/lambdalabs/config.py | 18 +++++++++++++++--- 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/cli/dstack/_internal/api/workflow_api.py b/cli/dstack/_internal/api/workflow_api.py index 431447158..096ecb920 100644 --- a/cli/dstack/_internal/api/workflow_api.py +++ b/cli/dstack/_internal/api/workflow_api.py @@ -58,6 +58,13 @@ def get_current_backend() -> Backend: raise DstackError( "Dependencies for GCP backend are not installed. Run `pip install dstack[gcp]`." ) + elif current_backend_type == "lambda": + try: + from dstack._internal.backend.lambdalabs import LambdaBackend as backend_class + except ImportError: + raise DstackError( + "Dependencies for LambdaLabs backend are not installed. Run `pip install dstack[lambda]`." 
+ ) elif current_backend_type == "local": from dstack._internal.backend.local import LocalBackend as backend_class current_backend = backend_class.load() diff --git a/cli/dstack/_internal/backend/lambdalabs/__init__.py b/cli/dstack/_internal/backend/lambdalabs/__init__.py index 76afee236..e8416e16d 100644 --- a/cli/dstack/_internal/backend/lambdalabs/__init__.py +++ b/cli/dstack/_internal/backend/lambdalabs/__init__.py @@ -5,7 +5,6 @@ from botocore.client import BaseClient from dstack._internal.backend.aws import logs -from dstack._internal.backend.aws.config import AWSConfig from dstack._internal.backend.aws.secrets import AWSSecretsManager from dstack._internal.backend.aws.storage import AWSStorage from dstack._internal.backend.base import Backend @@ -56,7 +55,10 @@ def __init__( @classmethod def load(cls) -> Optional["LambdaBackend"]: - return None + config = LambdaConfig.load() + if config is None: + return None + return cls(config) def _s3_client(self) -> BaseClient: return self._get_client("s3") diff --git a/cli/dstack/_internal/backend/lambdalabs/config.py b/cli/dstack/_internal/backend/lambdalabs/config.py index b48b29737..6bd95137c 100644 --- a/cli/dstack/_internal/backend/lambdalabs/config.py +++ b/cli/dstack/_internal/backend/lambdalabs/config.py @@ -1,8 +1,10 @@ -from typing import List +from typing import Dict, List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from typing_extensions import Literal +from dstack._internal.backend.base.config import BackendConfig + class AWSStorageConfigCredentials(BaseModel): access_key: str @@ -16,8 +18,18 @@ class AWSStorageConfig(BaseModel): credentials: AWSStorageConfigCredentials -class LambdaConfig(BaseModel): +class LambdaConfig(BackendConfig, BaseModel): backend: Literal["lambda"] = "lambda" regions: List[str] api_key: str storage_config: AWSStorageConfig + + def serialize(self) -> Dict: + return self.dict() + + @classmethod + def deserialize(cls, config_data: Dict) 
-> Optional["LambdaConfig"]: + try: + return LambdaConfig.parse_obj(config_data) + except ValidationError: + return None