Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide a way to ignore portions of the signature for cache key calcu… #4324

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions flyteidl/protos/flyteidl/core/tasks.proto
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ message TaskMetadata {
// task creates a k8s Pod. If this value is set, the specified PodTemplate will be used instead of, but applied
// identically as, the default PodTemplate configured in FlytePropeller.
string pod_template_name = 12;

// cache_ignore_input_vars lists the input variables that should be excluded when calculating the hash for the cache.
repeated string cache_ignore_input_vars = 13;
troychiu marked this conversation as resolved.
Show resolved Hide resolved
}

// A Task structure that uniquely identifies a task in the system
Expand Down
9 changes: 5 additions & 4 deletions flyteplugins/go/tasks/pluginmachinery/catalog/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ type Metadata struct {

// An identifier for a catalog object.
type Key struct {
Identifier core.Identifier
CacheVersion string
TypedInterface core.TypedInterface
InputReader io.InputReader
Identifier core.Identifier
CacheVersion string
CacheIgnoreInputVars []string
TypedInterface core.TypedInterface
InputReader io.InputReader
}

func (k Key) String() string {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (m *CatalogClient) Get(ctx context.Context, key catalog.Key) (catalog.Entry
inputs = retInputs
}

tag, err := GenerateArtifactTagName(ctx, inputs)
tag, err := GenerateArtifactTagName(ctx, inputs, &key.CacheIgnoreInputVars)
if err != nil {
logger.Errorf(ctx, "DataCatalog failed to generate tag for inputs %+v, err: %+v", inputs, err)
return catalog.Entry{}, err
Expand Down Expand Up @@ -233,7 +233,7 @@ func (m *CatalogClient) CreateArtifact(ctx context.Context, key catalog.Key, dat
logger.Debugf(ctx, "Created artifact: %v, with %v outputs from execution %+v", cachedArtifact.Id, len(artifactDataList), metadata)

// Tag the artifact since it is the cached artifact
tagName, err := GenerateArtifactTagName(ctx, inputs)
tagName, err := GenerateArtifactTagName(ctx, inputs, &key.CacheIgnoreInputVars)
if err != nil {
logger.Errorf(ctx, "Failed to generate tag for artifact %+v, err: %+v", cachedArtifact.Id, err)
return catalog.Status{}, err
Expand Down Expand Up @@ -273,7 +273,7 @@ func (m *CatalogClient) UpdateArtifact(ctx context.Context, key catalog.Key, dat
artifactDataList = append(artifactDataList, artifactData)
}

tagName, err := GenerateArtifactTagName(ctx, inputs)
tagName, err := GenerateArtifactTagName(ctx, inputs, &key.CacheIgnoreInputVars)
if err != nil {
logger.Errorf(ctx, "Failed to generate artifact tag name for key %+v, dataset %+v and execution %+v, err: %+v", key, datasetID, metadata, err)
return catalog.Status{}, err
Expand Down Expand Up @@ -378,7 +378,7 @@ func (m *CatalogClient) GetOrExtendReservation(ctx context.Context, key catalog.
inputs = retInputs
}

tag, err := GenerateArtifactTagName(ctx, inputs)
tag, err := GenerateArtifactTagName(ctx, inputs, &key.CacheIgnoreInputVars)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -418,7 +418,7 @@ func (m *CatalogClient) ReleaseReservation(ctx context.Context, key catalog.Key,
inputs = retInputs
}

tag, err := GenerateArtifactTagName(ctx, inputs)
tag, err := GenerateArtifactTagName(ctx, inputs, &key.CacheIgnoreInputVars)
if err != nil {
return err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/catalog"
"github.com/flyteorg/flyte/flytepropeller/pkg/compiler/validators"
"github.com/flyteorg/flyte/flytestdlib/pbhash"
"golang.org/x/exp/slices"
)

const cachedTaskTag = "flyte_cached"
Expand Down Expand Up @@ -114,9 +115,21 @@ func generateTaskSignatureHash(ctx context.Context, taskInterface core.TypedInte
return fmt.Sprintf("%v-%v", inputHashString, outputHashString), nil
}

// Generate a tag by hashing the input values
func GenerateArtifactTagName(ctx context.Context, inputs *core.LiteralMap) (string, error) {
hashString, err := catalog.HashLiteralMap(ctx, inputs)
// Generate a tag by hashing the input values which are not in cacheIgnoreInputVars
func GenerateArtifactTagName(ctx context.Context, inputs *core.LiteralMap, cacheIgnoreInputVars *[]string) (string, error) {
var inputsAfterIgnore *core.LiteralMap
if cacheIgnoreInputVars != nil {
inputsAfterIgnore = &core.LiteralMap{Literals: make(map[string]*core.Literal)}
for name, literal := range inputs.Literals {
if slices.Contains(*cacheIgnoreInputVars, name) {
continue
}
inputsAfterIgnore.Literals[name] = literal
}
} else {
inputsAfterIgnore = inputs
}
Comment on lines +120 to +131
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than creating a new LiteralMap here, would it make more sense to pass the cacheIgnoreInputVars to the HashLiteralMap function? Then in this code here we just need to update to:

	for name, literal := range literalMap.Literals {
	        if !slices.Contains(cacheIgnoreInputVars, name) {
		    hashifiedLiteralMap[name] = hashify(literal)
		}
	}

This achieves the same outcome in 2 LoC rather than 12. I think it makes more sense to only create a new LiteralMap in one place as well.

hashString, err := catalog.HashLiteralMap(ctx, inputsAfterIgnore)
if err != nil {
return "", err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,16 @@ func TestGenerateArtifactTagName(t *testing.T) {
literalMap, err := coreutils.MakeLiteralMap(map[string]interface{}{"1": 1, "2": 2})
assert.NoError(t, err)

tag, err := GenerateArtifactTagName(context.TODO(), literalMap)
tag, err := GenerateArtifactTagName(context.TODO(), literalMap, nil)
assert.NoError(t, err)
assert.Equal(t, "flyte_cached-GQid5LjHbakcW68DS3P2jp80QLbiF0olFHF2hTh5bg8", tag)
}

// TestGenerateArtifactTagNameWithIgnore checks that an input named in the
// ignore list does not contribute to the generated tag. The inputs are
// {"1", "2", "3"} with "3" ignored, and the expected tag is the same literal
// that TestGenerateArtifactTagName asserts for inputs {"1", "2"} alone.
func TestGenerateArtifactTagNameWithIgnore(t *testing.T) {
	ignored := []string{"3"}

	inputs, err := coreutils.MakeLiteralMap(map[string]interface{}{"1": 1, "2": 2, "3": 3})
	assert.NoError(t, err)

	// The function takes a pointer to the slice; nil means "ignore nothing".
	got, err := GenerateArtifactTagName(context.TODO(), inputs, &ignored)
	assert.NoError(t, err)

	const want = "flyte_cached-GQid5LjHbakcW68DS3P2jp80QLbiF0olFHF2hTh5bg8"
	assert.Equal(t, want, got)
}
Expand Down
9 changes: 5 additions & 4 deletions flytepropeller/pkg/controller/nodes/task/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
}

return catalog.Key{
Identifier: *taskTemplate.Id,
CacheVersion: taskTemplate.Metadata.DiscoveryVersion,
TypedInterface: *taskTemplate.Interface,
InputReader: nCtx.InputReader(),
Identifier: *taskTemplate.Id,
CacheVersion: taskTemplate.Metadata.DiscoveryVersion,
CacheIgnoreInputVars: taskTemplate.Metadata.CacheIgnoreInputVars,

Check failure on line 31 in flytepropeller/pkg/controller/nodes/task/cache.go

View workflow job for this annotation

GitHub Actions / compile

taskTemplate.Metadata.CacheIgnoreInputVars undefined (type *"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core".TaskMetadata has no field or method CacheIgnoreInputVars)
TypedInterface: *taskTemplate.Interface,
InputReader: nCtx.InputReader(),
}, nil
}

Expand Down