Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(frontend&backend): Add UI support for object store customization and prefixes #10787

Merged
merged 8 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions backend/src/v2/component/importer_launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package component

import (
"context"
"encoding/json"
"fmt"
"github.com/kubeflow/pipelines/backend/src/v2/objectstore"

pb "github.com/kubeflow/pipelines/third_party/ml-metadata/go/ml_metadata"

Expand Down Expand Up @@ -225,6 +227,10 @@ func (l *ImportLauncher) ImportSpecToMLMDArtifact(ctx context.Context) (artifact

state := pb.Artifact_LIVE

provider, err := objectstore.ParseProviderFromPath(artifactUri)
if err != nil {
return nil, fmt.Errorf("No Provider scheme found in artifact Uri: %s", artifactUri)
}
artifact = &pb.Artifact{
TypeId: &artifactTypeId,
State: &state,
Expand All @@ -241,6 +247,20 @@ func (l *ImportLauncher) ImportSpecToMLMDArtifact(ctx context.Context) (artifact
artifact.CustomProperties[k] = value
}
}

// Assume all imported artifacts will rely on execution environment for store provider session info
storeSessionInfo := objectstore.SessionInfo{
Provider: provider,
Params: map[string]string{
"fromEnv": "true",
},
}
storeSessionInfoJSON, err := json.Marshal(storeSessionInfo)
if err != nil {
return nil, err
}
storeSessionInfoStr := string(storeSessionInfoJSON)
artifact.CustomProperties["store_session_info"] = metadata.StringValue(storeSessionInfoStr)
return artifact, nil
}

Expand Down
2 changes: 1 addition & 1 deletion backend/src/v2/component/launcher_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ func uploadOutputArtifacts(ctx context.Context, executorInput *pipelinespec.Exec
if err != nil {
return nil, fmt.Errorf("failed to determine schema for output %q: %w", name, err)
}
mlmdArtifact, err := opts.metadataClient.RecordArtifact(ctx, name, schema, outputArtifact, pb.Artifact_LIVE)
mlmdArtifact, err := opts.metadataClient.RecordArtifact(ctx, name, schema, outputArtifact, pb.Artifact_LIVE, opts.bucketConfig)
if err != nil {
return nil, metadataErr(err)
}
Expand Down
3 changes: 1 addition & 2 deletions backend/src/v2/config/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,10 @@ func InPodName() (string, error) {
}

func (c *Config) GetStoreSessionInfo(path string) (objectstore.SessionInfo, error) {
bucketConfig, err := objectstore.ParseBucketPathToConfig(path)
provider, err := objectstore.ParseProviderFromPath(path)
if err != nil {
return objectstore.SessionInfo{}, err
}
provider := strings.TrimSuffix(bucketConfig.Scheme, "://")
bucketProviders, err := c.getBucketProviders()
if err != nil {
return objectstore.SessionInfo{}, err
Expand Down
53 changes: 34 additions & 19 deletions backend/src/v2/metadata/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ package metadata

import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/kubeflow/pipelines/backend/src/v2/objectstore"
"path"
"strconv"
"strings"
Expand Down Expand Up @@ -90,7 +92,7 @@ type ClientInterface interface {
GetArtifactName(ctx context.Context, artifactId int64) (string, error)
GetArtifacts(ctx context.Context, ids []int64) ([]*pb.Artifact, error)
GetOutputArtifactsByExecutionId(ctx context.Context, executionId int64) (map[string]*OutputArtifact, error)
RecordArtifact(ctx context.Context, outputName, schema string, runtimeArtifact *pipelinespec.RuntimeArtifact, state pb.Artifact_State) (*OutputArtifact, error)
RecordArtifact(ctx context.Context, outputName, schema string, runtimeArtifact *pipelinespec.RuntimeArtifact, state pb.Artifact_State, bucketConfig *objectstore.Config) (*OutputArtifact, error)
GetOrInsertArtifactType(ctx context.Context, schema string) (typeID int64, err error)
FindMatchedArtifact(ctx context.Context, artifactToMatch *pb.Artifact, pipelineContextId int64) (matchedArtifact *pb.Artifact, err error)
}
Expand Down Expand Up @@ -301,11 +303,11 @@ func (c *Client) GetPipeline(ctx context.Context, pipelineName, runID, namespace
}
glog.Infof("Pipeline Context: %+v", pipelineContext)
metadata := map[string]*pb.Value{
keyNamespace: stringValue(namespace),
keyResourceName: stringValue(runResource),
keyNamespace: StringValue(namespace),
keyResourceName: StringValue(runResource),
// pipeline root of this run
keyPipelineRoot: stringValue(GenerateOutputURI(pipelineRoot, []string{pipelineName, runID}, true)),
keyStoreSessionInfo: stringValue(storeSessionInfo),
keyPipelineRoot: StringValue(GenerateOutputURI(pipelineRoot, []string{pipelineName, runID}, true)),
keyStoreSessionInfo: StringValue(storeSessionInfo),
}
runContext, err := c.getOrInsertContext(ctx, runID, pipelineRunContextType, metadata)
glog.Infof("Pipeline Run Context: %+v", runContext)
Expand Down Expand Up @@ -401,7 +403,7 @@ func (c *Client) getExecutionTypeID(ctx context.Context, executionType *pb.Execu
return eType.GetTypeId(), nil
}

func stringValue(s string) *pb.Value {
// StringValue wraps s in a pb.Value whose Value field carries s as a
// string value.
func StringValue(s string) *pb.Value {
	wrapped := &pb.Value_StringValue{StringValue: s}
	return &pb.Value{Value: wrapped}
}

Expand Down Expand Up @@ -531,8 +533,8 @@ func (c *Client) CreateExecution(ctx context.Context, pipeline *Pipeline, config
TypeId: &typeID,
CustomProperties: map[string]*pb.Value{
// We should support overriding display name in the future, for now it defaults to task name.
keyDisplayName: stringValue(config.TaskName),
keyTaskName: stringValue(config.TaskName),
keyDisplayName: StringValue(config.TaskName),
keyTaskName: StringValue(config.TaskName),
},
}
if config.Name != "" {
Expand All @@ -555,15 +557,15 @@ func (c *Client) CreateExecution(ctx context.Context, pipeline *Pipeline, config
e.CustomProperties[keyIterationCount] = intValue(int64(*config.IterationCount))
}
if config.ExecutionType == ContainerExecutionTypeName {
e.CustomProperties[keyPodName] = stringValue(config.PodName)
e.CustomProperties[keyPodUID] = stringValue(config.PodUID)
e.CustomProperties[keyNamespace] = stringValue(config.Namespace)
e.CustomProperties[keyImage] = stringValue(config.Image)
e.CustomProperties[keyPodName] = StringValue(config.PodName)
e.CustomProperties[keyPodUID] = StringValue(config.PodUID)
e.CustomProperties[keyNamespace] = StringValue(config.Namespace)
e.CustomProperties[keyImage] = StringValue(config.Image)
if config.CachedMLMDExecutionID != "" {
e.CustomProperties[keyCachedExecutionID] = stringValue(config.CachedMLMDExecutionID)
e.CustomProperties[keyCachedExecutionID] = StringValue(config.CachedMLMDExecutionID)
}
if config.FingerPrint != "" {
e.CustomProperties[keyCacheFingerPrint] = stringValue(config.FingerPrint)
e.CustomProperties[keyCacheFingerPrint] = StringValue(config.FingerPrint)
}
}
if config.InputParameters != nil {
Expand Down Expand Up @@ -623,9 +625,9 @@ func (c *Client) PrePublishExecution(ctx context.Context, execution *Execution,
if e.CustomProperties == nil {
e.CustomProperties = make(map[string]*pb.Value)
}
e.CustomProperties[keyPodName] = stringValue(config.PodName)
e.CustomProperties[keyPodUID] = stringValue(config.PodUID)
e.CustomProperties[keyNamespace] = stringValue(config.Namespace)
e.CustomProperties[keyPodName] = StringValue(config.PodName)
e.CustomProperties[keyPodUID] = StringValue(config.PodUID)
e.CustomProperties[keyNamespace] = StringValue(config.Namespace)
e.LastKnownState = pb.Execution_RUNNING.Enum()

_, err := c.svc.PutExecution(ctx, &pb.PutExecutionRequest{
Expand Down Expand Up @@ -889,7 +891,7 @@ func SchemaToArtifactType(schema string) (*pb.ArtifactType, error) {
}

// RecordArtifact ...
func (c *Client) RecordArtifact(ctx context.Context, outputName, schema string, runtimeArtifact *pipelinespec.RuntimeArtifact, state pb.Artifact_State) (*OutputArtifact, error) {
func (c *Client) RecordArtifact(ctx context.Context, outputName, schema string, runtimeArtifact *pipelinespec.RuntimeArtifact, state pb.Artifact_State, bucketConfig *objectstore.Config) (*OutputArtifact, error) {
artifact, err := toMLMDArtifact(runtimeArtifact)
if err != nil {
return nil, err
Expand All @@ -911,7 +913,20 @@ func (c *Client) RecordArtifact(ctx context.Context, outputName, schema string,
}
if _, ok := artifact.CustomProperties["display_name"]; !ok {
// display name default value
artifact.CustomProperties["display_name"] = stringValue(outputName)
artifact.CustomProperties["display_name"] = StringValue(outputName)
}

// An artifact can belong to an external store specified via kfp-launcher
// or via executor environment (e.g. IRSA)
// This allows us to easily identify where to locate the artifact both
// in user executor environment as well as in kfp ui
if _, ok := artifact.CustomProperties["store_session_info"]; !ok {
storeSessionInfoJSON, err1 := json.Marshal(bucketConfig.SessionInfo)
if err1 != nil {
return nil, err1
}
storeSessionInfoStr := string(storeSessionInfoJSON)
artifact.CustomProperties["store_session_info"] = StringValue(storeSessionInfoStr)
}

res, err := c.svc.PutArtifacts(ctx, &pb.PutArtifactsRequest{
Expand Down
3 changes: 2 additions & 1 deletion backend/src/v2/metadata/client_fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package metadata

import (
"context"
"github.com/kubeflow/pipelines/backend/src/v2/objectstore"

"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
pb "github.com/kubeflow/pipelines/third_party/ml-metadata/go/ml_metadata"
Expand Down Expand Up @@ -82,7 +83,7 @@ func (c *FakeClient) GetOutputArtifactsByExecutionId(ctx context.Context, execut
return nil, nil
}

func (c *FakeClient) RecordArtifact(ctx context.Context, outputName, schema string, runtimeArtifact *pipelinespec.RuntimeArtifact, state pb.Artifact_State) (*OutputArtifact, error) {
// RecordArtifact is a stub on the fake client: it ignores all of its
// arguments, including bucketConfig, and always returns (nil, nil).
func (c *FakeClient) RecordArtifact(ctx context.Context, outputName, schema string, runtimeArtifact *pipelinespec.RuntimeArtifact, state pb.Artifact_State, bucketConfig *objectstore.Config) (*OutputArtifact, error) {
return nil, nil
}

Expand Down
10 changes: 10 additions & 0 deletions backend/src/v2/objectstore/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,16 @@ func ParseBucketConfigForArtifactURI(uri string) (*Config, error) {
}, nil
}

// ParseProviderFromPath parses uri with ParseBucketPathToConfig and returns
// the resulting scheme with its trailing "://" removed. That value is used as
// the Provider string. Any parse error is returned unchanged together with an
// empty provider.
func ParseProviderFromPath(uri string) (string, error) {
	cfg, parseErr := ParseBucketPathToConfig(uri)
	if parseErr != nil {
		return "", parseErr
	}
	provider := strings.TrimSuffix(cfg.Scheme, "://")
	return provider, nil
}

func MinioDefaultEndpoint() string {
// Discover minio-service in the same namespace by env var.
// https://kubernetes.io/docs/concepts/services-networking/service/#environment-variables
Expand Down
6 changes: 4 additions & 2 deletions backend/src/v2/objectstore/object_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"k8s.io/client-go/kubernetes"
"os"
"path/filepath"
"regexp"
"strings"
)

Expand Down Expand Up @@ -261,12 +262,13 @@ func createS3BucketSession(ctx context.Context, namespace string, sessionInfo *S

// AWS Specific:
// Path-style S3 endpoints, which are commonly used, may fall into either of two subdomains:
// 1) s3.amazonaws.com
// 1) [https://]s3.amazonaws.com
// 2) s3.<AWS Region>.amazonaws.com
// for (1) the endpoint is not required, thus we skip it, otherwise the writer will fail to close due to region mismatch.
// https://aws.amazon.com/blogs/infrastructure-and-automation/best-practices-for-using-amazon-s3-endpoints-in-aws-cloudformation-templates/
// https://docs.aws.amazon.com/sdk-for-go/api/aws/session/
if strings.ToLower(params.Endpoint) != "s3.amazonaws.com" {
awsEndpoint, _ := regexp.MatchString(`^(https://)?s3.amazonaws.com`, strings.ToLower(params.Endpoint))
if !awsEndpoint {
config.Endpoint = aws.String(params.Endpoint)
}

Expand Down
16 changes: 8 additions & 8 deletions frontend/server/aws-helper.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import fetch from 'node-fetch';
import { awsInstanceProfileCredentials, isS3Endpoint } from './aws-helper';
import { awsInstanceProfileCredentials, isAWSS3Endpoint } from './aws-helper';

// mock node-fetch module
jest.mock('node-fetch');
Expand Down Expand Up @@ -107,30 +107,30 @@ describe('awsInstanceProfileCredentials', () => {

describe('isS3Endpoint', () => {
it('checks a valid s3 endpoint', () => {
expect(isS3Endpoint('s3.amazonaws.com')).toBe(true);
expect(isAWSS3Endpoint('s3.amazonaws.com')).toBe(true);
});

it('checks a valid s3 regional endpoint', () => {
expect(isS3Endpoint('s3.dualstack.us-east-1.amazonaws.com')).toBe(true);
expect(isAWSS3Endpoint('s3.dualstack.us-east-1.amazonaws.com')).toBe(true);
});

it('checks a valid s3 cn endpoint', () => {
expect(isS3Endpoint('s3.cn-north-1.amazonaws.com.cn')).toBe(true);
expect(isAWSS3Endpoint('s3.cn-north-1.amazonaws.com.cn')).toBe(true);
});

it('checks a valid s3 fips GovCloud endpoint', () => {
expect(isS3Endpoint('s3-fips.us-gov-west-1.amazonaws.com')).toBe(true);
expect(isAWSS3Endpoint('s3-fips.us-gov-west-1.amazonaws.com')).toBe(true);
});

it('checks a valid s3 PrivateLink endpoint', () => {
expect(isS3Endpoint('vpce-1a2b3c4d-5e6f.s3.us-east-1.vpce.amazonaws.com')).toBe(true);
expect(isAWSS3Endpoint('vpce-1a2b3c4d-5e6f.s3.us-east-1.vpce.amazonaws.com')).toBe(true);
});

it('checks an invalid s3 endpoint', () => {
expect(isS3Endpoint('amazonaws.com')).toBe(false);
expect(isAWSS3Endpoint('amazonaws.com')).toBe(false);
});

it('checks non-s3 endpoint', () => {
expect(isS3Endpoint('minio.kubeflow')).toBe(false);
expect(isAWSS3Endpoint('minio.kubeflow')).toBe(false);
});
});
2 changes: 1 addition & 1 deletion frontend/server/aws-helper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async function getIAMInstanceProfile(): Promise<string | undefined> {
*
* @param endpoint minio endpoint to check.
*/
export function isS3Endpoint(endpoint: string = ''): boolean {
export function isAWSS3Endpoint(endpoint: string = ''): boolean {
  // Case-insensitive, unanchored search: "s3", then any characters, then
  // ".amazonaws.com" anywhere in the endpoint string.
  const awsS3HostPattern = /s3.{0,}\.amazonaws\.com\.?.{0,}/i;
  return awsS3HostPattern.test(endpoint);
}

Expand Down
Loading
Loading