Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions flytecopilot/cmd/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"time"

"github.com/golang/protobuf/proto"
"github.com/spf13/cobra"

"github.com/flyteorg/flyte/flytecopilot/data"
Expand Down Expand Up @@ -49,6 +50,11 @@ func GetUploadModeVals() []string {
}

func (d *DownloadOptions) Download(ctx context.Context) error {
inputInterface := &core.VariableMap{}
if err := proto.Unmarshal(d.inputInterface, inputInterface); err != nil {
logger.Warnf(ctx, "Bad input interface passed, failed to unmarshal err: %s", err)
}

if d.remoteOutputsPrefix == "" {
return fmt.Errorf("to-output-prefix is required")
}
Expand Down Expand Up @@ -77,7 +83,7 @@ func (d *DownloadOptions) Download(ctx context.Context) error {
childCtx, cancelFn = context.WithTimeout(ctx, d.timeout)
}
defer cancelFn()
err := dl.DownloadInputs(childCtx, storage.DataReference(d.remoteInputsPath), d.localDirectoryPath)
err := dl.DownloadInputs(childCtx, inputInterface, storage.DataReference(d.remoteInputsPath), d.localDirectoryPath)
if err != nil {
logger.Errorf(ctx, "Downloading failed, err %s", err)
return err
Expand Down Expand Up @@ -116,6 +122,6 @@ func NewDownloadCommand(opts *RootOptions) *cobra.Command {
downloadCmd.Flags().StringVarP(&downloadOpts.metadataFormat, "format", "m", core.DataLoadingConfig_JSON.String(), fmt.Sprintf("What should be the output format for the primitive and structured types. Options [%v]", GetFormatVals()))
downloadCmd.Flags().StringVarP(&downloadOpts.downloadMode, "download-mode", "d", core.IOStrategy_DOWNLOAD_EAGER.String(), fmt.Sprintf("Download mode to use. Options [%v]", GetDownloadModeVals()))
downloadCmd.Flags().DurationVarP(&downloadOpts.timeout, "timeout", "t", time.Hour*1, "Max time to allow for downloads to complete, default is 1H")
downloadCmd.Flags().BytesBase64VarP(&downloadOpts.inputInterface, "input-interface", "i", nil, "Input interface proto message - core.VariableMap, base64 encoced string")
downloadCmd.Flags().BytesBase64VarP(&downloadOpts.inputInterface, "input-interface", "i", nil, "Input interface proto message - core.VariableMap, base64 encoded string")
return downloadCmd
}
78 changes: 78 additions & 0 deletions flytecopilot/cmd/download_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strings"
"testing"

"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"

"github.com/flyteorg/flyte/flyteidl/clients/go/coreutils"
Expand Down Expand Up @@ -137,6 +138,83 @@ func TestDownloadOptions_Download(t *testing.T) {
assert.ElementsMatch(t, []string{"inputs.json", "inputs.pb", "x", "y", "blob"}, collectFile(tmpDir))
})

t.Run("primitiveAndBlobInputsWithFileExtension", func(t *testing.T) {
tmpDir, err := ioutil.TempDir(tmpFolderLocation, tmpPrefix)
assert.NoError(t, err)
defer func() {
assert.NoError(t, os.RemoveAll(tmpDir))
}()
dopts.localDirectoryPath = tmpDir

s := promutils.NewTestScope()
store, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, s.NewSubScope("storage"))
assert.NoError(t, err)
dopts.RootOptions = &RootOptions{
Scope: s,
Store: store,
}

iface := &core.VariableMap{
Variables: map[string]*core.Variable{
"blob": {
Type: &core.LiteralType{Type: &core.LiteralType_Blob{Blob: &core.BlobType{Dimensionality: core.BlobType_SINGLE, Format: "xyz", FileExtension: "xyz", EnableLegacyFilename: false}}},
},
"legacy_blob": {
Type: &core.LiteralType{Type: &core.LiteralType_Blob{Blob: &core.BlobType{Dimensionality: core.BlobType_SINGLE, Format: "xyz", FileExtension: "xyz", EnableLegacyFilename: true}}},
},
},
}
d, err := proto.Marshal(iface)
assert.NoError(t, err)
dopts.inputInterface = d

blobLoc := storage.DataReference("blob-loc")
br := bytes.NewBuffer([]byte("Hello World!"))
assert.NoError(t, store.WriteRaw(ctx, blobLoc, int64(br.Len()), storage.Options{}, br))
assert.NoError(t, store.WriteProtobuf(ctx, storage.DataReference(inputPath), storage.Options{}, &core.LiteralMap{
Literals: map[string]*core.Literal{
"x": coreutils.MustMakePrimitiveLiteral(1),
"y": coreutils.MustMakePrimitiveLiteral("hello"),
"blob": {Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Blob{
Blob: &core.Blob{
Uri: blobLoc.String(),
Metadata: &core.BlobMetadata{
Type: &core.BlobType{
Dimensionality: core.BlobType_SINGLE,
Format: "xyz",
FileExtension: "xyz",
EnableLegacyFilename: false,
},
},
},
},
},
}},
"legacy_blob": {Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Blob{
Blob: &core.Blob{
Uri: blobLoc.String(),
Metadata: &core.BlobMetadata{
Type: &core.BlobType{
Dimensionality: core.BlobType_SINGLE,
Format: "xyz",
FileExtension: "xyz",
EnableLegacyFilename: true,
},
},
},
},
},
}},
},
}))
assert.NoError(t, dopts.Download(ctx), "Download Operation failed")
assert.ElementsMatch(t, []string{"inputs.json", "inputs.pb", "x", "y", "blob.xyz", "legacy_blob", "legacy_blob.xyz"}, collectFile(tmpDir))
})

t.Run("primitiveAndMissingBlobInputs", func(t *testing.T) {
tmpDir, err := ioutil.TempDir(tmpFolderLocation, tmpPrefix)
assert.NoError(t, err)
Expand Down
72 changes: 63 additions & 9 deletions flytecopilot/data/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,42 @@ type Downloader struct {
mode core.IOStrategy_DownloadMode
}

// resolveVarFilenames maps every variable declared in vars to the list of local
// filenames its downloaded value should be written to.
//
// Simple variables, and blobs whose FileExtension is empty (the default), keep
// the historical layout: a single file named after the variable. For example, a
// data: FlyteFile["csv"] blob is written to "data" even though Format="csv".
//
// A blob with a non-empty FileExtension, e.g.
// data: Annotated[FlyteFile["csv"], FileDownloadConfig(file_extension="csv")],
// is written to "<name>.<extension>", i.e. "data.csv". When EnableLegacyFilename
// is also set, the bare variable name is listed after it, so the blob ends up
// under both "data.csv" and "data" to stay backward compatible.
//
// Any variable whose literal type is neither simple nor blob yields an error.
func resolveVarFilenames(vars *core.VariableMap) (map[string][]string, error) {
	declared := vars.GetVariables()
	result := make(map[string][]string, len(declared))
	for name, v := range declared {
		litType := v.GetType()
		switch t := litType.GetType().(type) {
		case *core.LiteralType_Simple:
			result[name] = []string{name}
		case *core.LiteralType_Blob:
			// Default: no extension configured, keep the plain variable name.
			names := []string{name}
			if ext := t.Blob.GetFileExtension(); ext != "" {
				names = []string{name + "." + ext}
				// Legacy mode additionally keeps the extension-less filename.
				if t.Blob.GetEnableLegacyFilename() {
					names = append(names, name)
				}
			}
			result[name] = names
		default:
			return nil, fmt.Errorf("currently CoPilot downloader does not support [%s], system error", litType)
		}
	}
	logger.Infof(context.Background(), "varFilenames: %v", result)
	return result, nil
}

// TODO add timeout and rate limit
// TODO download in chunks
func (d Downloader) handleBlob(ctx context.Context, blob *core.Blob, toPath string) (interface{}, error) {
Expand Down Expand Up @@ -432,7 +468,7 @@ func (d Downloader) handleLiteral(ctx context.Context, lit *core.Literal, filePa
if err != nil {
return nil, nil, errors.Wrapf(err, "failed to create directory [%s]", filePath)
}
v, m, err := d.RecursiveDownload(ctx, lit.GetMap(), filePath, writeToFile)
v, m, err := d.RecursiveDownload(ctx, lit.GetMap(), filePath, make(map[string][]string), writeToFile)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -468,7 +504,7 @@ type downloadedResult struct {
v interface{}
}

func (d Downloader) RecursiveDownload(ctx context.Context, inputs *core.LiteralMap, dir string, writePrimitiveToFile bool) (VarMap, *core.LiteralMap, error) {
func (d Downloader) RecursiveDownload(ctx context.Context, inputs *core.LiteralMap, dir string, varFilenames map[string][]string, writePrimitiveToFile bool) (VarMap, *core.LiteralMap, error) {
childCtx, cancel := context.WithCancel(ctx)
defer cancel()
if inputs == nil || len(inputs.GetLiterals()) == 0 {
Expand All @@ -486,14 +522,26 @@ func (d Downloader) RecursiveDownload(ctx context.Context, inputs *core.LiteralM
}
logger.Infof(ctx, "read object at location [%s]", offloadedMetadataURI)
}
varPath := path.Join(dir, variable)
lit := literal
f[variable] = futures.NewAsyncFuture(childCtx, func(ctx2 context.Context) (interface{}, error) {
v, lit, err := d.handleLiteral(ctx2, lit, varPath, writePrimitiveToFile)
if err != nil {
return nil, err
var filenames []string
var resultLit *core.Literal
var resultV interface{}
var err error
if len(varFilenames[variable]) == 0 {
filenames = []string{variable}
} else {
filenames = varFilenames[variable]
}
return downloadedResult{lit: lit, v: v}, nil
for _, filename := range filenames {
varPath := path.Join(dir, filename)
// TODO: Refactor handleLiteral to accept a list of file paths and return a list of downloaded results
resultV, resultLit, err = d.handleLiteral(ctx2, lit, varPath, writePrimitiveToFile)
if err != nil {
return nil, err
}
}
return downloadedResult{lit: resultLit, v: resultV}, nil
})
}

Expand All @@ -520,7 +568,7 @@ func (d Downloader) RecursiveDownload(ctx context.Context, inputs *core.LiteralM
return vmap, m, nil
}

func (d Downloader) DownloadInputs(ctx context.Context, inputRef storage.DataReference, outputDir string) error {
func (d Downloader) DownloadInputs(ctx context.Context, vars *core.VariableMap, inputRef storage.DataReference, outputDir string) error {
logger.Infof(ctx, "Downloading inputs from [%s]", inputRef)
defer logger.Infof(ctx, "Exited downloading inputs from [%s]", inputRef)
if err := os.MkdirAll(outputDir, os.ModePerm); err != nil {
Expand All @@ -533,7 +581,13 @@ func (d Downloader) DownloadInputs(ctx context.Context, inputRef storage.DataRef
logger.Errorf(ctx, "Failed to download inputs from [%s], err [%s]", inputRef, err)
return errors.Wrapf(err, "failed to download input metadata message from remote store")
}
varMap, lMap, err := d.RecursiveDownload(ctx, inputs, outputDir, true)

varFilenames, err := resolveVarFilenames(vars)
if err != nil {
return errors.Wrapf(err, "failed to resolve variable filenames")
}

varMap, lMap, err := d.RecursiveDownload(ctx, inputs, outputDir, varFilenames, true)
if err != nil {
return errors.Wrapf(err, "failed to download input variable from remote store")
}
Expand Down
4 changes: 2 additions & 2 deletions flytecopilot/data/download_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ func TestRecursiveDownload(t *testing.T) {
}
}()

varMap, lMap, err := d.RecursiveDownload(context.Background(), inputs, toPath, true)
varMap, lMap, err := d.RecursiveDownload(context.Background(), inputs, toPath, map[string][]string{}, true)
assert.NoError(t, err)
assert.NotNil(t, varMap)
assert.NotNil(t, lMap)
Expand Down Expand Up @@ -309,7 +309,7 @@ func TestRecursiveDownload(t *testing.T) {
}
}()

varMap, lMap, err := d.RecursiveDownload(context.Background(), inputs, toPath, true)
varMap, lMap, err := d.RecursiveDownload(context.Background(), inputs, toPath, map[string][]string{}, true)
assert.NoError(t, err)
assert.NotNil(t, varMap)
assert.NotNil(t, lMap)
Expand Down
4 changes: 2 additions & 2 deletions flytecopilot/data/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ func (u Uploader) handleBlobType(ctx context.Context, localPath string, toPath s
}
}

return coreutils.MakeLiteralForBlob(toPath, false, ""), nil
return coreutils.MakeLiteralForBlob(toPath, false, "", "", false), nil
}
size := info.Size()
// Should we make this a go routine as well, so that we can introduce timeouts
return coreutils.MakeLiteralForBlob(toPath, false, ""), UploadFileToStorage(ctx, fpath, toPath, size, u.store)
return coreutils.MakeLiteralForBlob(toPath, false, "", "", false), UploadFileToStorage(ctx, fpath, toPath, size, u.store)
}

func (u Uploader) RecursiveUpload(ctx context.Context, vars *core.VariableMap, fromPath string, metaOutputPath, dataRawPath storage.DataReference) error {
Expand Down
3 changes: 3 additions & 0 deletions flyteidl/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ vendor

dist
gen/pb_python/flyteidl.egg-info/
gen/pb_python/flyteidl/**/__pycache__/
pip-wheel-metadata/*.dist-info/
*.egg-info/

.virtualgo
docs/build/
Expand Down
8 changes: 8 additions & 0 deletions flyteidl/clients/go/assets/admin.swagger.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 6 additions & 4 deletions flyteidl/clients/go/coreutils/literals.go
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ func MakeLiteralForStructuredDataSet(path storage.DataReference, columns []*core
}
}

func MakeLiteralForBlob(path storage.DataReference, isDir bool, format string) *core.Literal {
func MakeLiteralForBlob(path storage.DataReference, isDir bool, format string, fileExtension string, enableLegacyFilename bool) *core.Literal {
dim := core.BlobType_SINGLE
if isDir {
dim = core.BlobType_MULTIPART
Expand All @@ -498,8 +498,10 @@ func MakeLiteralForBlob(path storage.DataReference, isDir bool, format string) *
Uri: path.String(),
Metadata: &core.BlobMetadata{
Type: &core.BlobType{
Dimensionality: dim,
Format: format,
Dimensionality: dim,
Format: format,
FileExtension: fileExtension,
EnableLegacyFilename: enableLegacyFilename,
},
},
},
Expand Down Expand Up @@ -601,7 +603,7 @@ func MakeLiteralForType(t *core.LiteralType, v interface{}) (*core.Literal, erro

case *core.LiteralType_Blob:
isDir := newT.Blob.GetDimensionality() == core.BlobType_MULTIPART
lv := MakeLiteralForBlob(storage.DataReference(fmt.Sprintf("%v", v)), isDir, newT.Blob.GetFormat())
lv := MakeLiteralForBlob(storage.DataReference(fmt.Sprintf("%v", v)), isDir, newT.Blob.GetFormat(), newT.Blob.GetFileExtension(), newT.Blob.GetEnableLegacyFilename())
return lv, nil

case *core.LiteralType_Schema:
Expand Down
12 changes: 8 additions & 4 deletions flyteidl/clients/go/coreutils/literals_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,9 +388,11 @@ func TestMakeLiteralForSimpleType(t *testing.T) {

func TestMakeLiteralForBlob(t *testing.T) {
type args struct {
path storage.DataReference
isDir bool
format string
path storage.DataReference
isDir bool
format string
fileExtension string
enableLegacyFilename bool
}
tests := []struct {
name string
Expand All @@ -399,10 +401,12 @@ func TestMakeLiteralForBlob(t *testing.T) {
}{
{"simple-key", args{path: "/key", isDir: false, format: "xyz"}, &core.Blob{Uri: "/key", Metadata: &core.BlobMetadata{Type: &core.BlobType{Format: "xyz", Dimensionality: core.BlobType_SINGLE}}}},
{"simple-dir", args{path: "/key", isDir: true, format: "xyz"}, &core.Blob{Uri: "/key", Metadata: &core.BlobMetadata{Type: &core.BlobType{Format: "xyz", Dimensionality: core.BlobType_MULTIPART}}}},
{"simple-key-with-extension", args{path: "/key", isDir: false, format: "xyz", fileExtension: "xyz", enableLegacyFilename: false}, &core.Blob{Uri: "/key", Metadata: &core.BlobMetadata{Type: &core.BlobType{Format: "xyz", Dimensionality: core.BlobType_SINGLE, FileExtension: "xyz", EnableLegacyFilename: false}}}},
{"simple-key-with-extension-and-legacy-filename", args{path: "/key", isDir: false, format: "xyz", fileExtension: "xyz", enableLegacyFilename: true}, &core.Blob{Uri: "/key", Metadata: &core.BlobMetadata{Type: &core.BlobType{Format: "xyz", Dimensionality: core.BlobType_SINGLE, FileExtension: "xyz", EnableLegacyFilename: true}}}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := MakeLiteralForBlob(tt.args.path, tt.args.isDir, tt.args.format); !reflect.DeepEqual(got.GetScalar().GetBlob(), tt.want) {
if got := MakeLiteralForBlob(tt.args.path, tt.args.isDir, tt.args.format, tt.args.fileExtension, tt.args.enableLegacyFilename); !reflect.DeepEqual(got.GetScalar().GetBlob(), tt.want) {
t.Errorf("MakeLiteralForBlob() = %v, want %v", got, tt.want)
}
})
Expand Down
Loading
Loading