Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing caching on maptasks when using partials #4344

Merged
merged 2 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 15 additions & 27 deletions flyteplugins/go/tasks/plugins/array/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,17 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex
return state, errors.Errorf(errors.MetadataAccessFailed, "Could not read inputs and therefore failed to determine array job size")
}

// identify and validate the size of the array job
size := -1
var literalCollection *idlCore.LiteralCollection
literals := make([][]*idlCore.Literal, 0)
discoveredInputNames := make([]string, 0)
for inputName, literal := range inputs.Literals {
for _, literal := range inputs.Literals {
if literalCollection = literal.GetCollection(); literalCollection != nil {
// validate length of input list
if size != -1 && size != len(literalCollection.Literals) {
state = state.SetPhase(arrayCore.PhasePermanentFailure, 0).SetReason("all maptask input lists must be the same length")
return state, nil
}

literals = append(literals, literalCollection.Literals)
discoveredInputNames = append(discoveredInputNames, inputName)

size = len(literalCollection.Literals)
}
}
Expand All @@ -110,7 +106,7 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex
arrayJobSize = int64(size)

// build input readers
inputReaders = ConstructStaticInputReaders(tCtx.InputReader(), literals, discoveredInputNames)
inputReaders = ConstructStaticInputReaders(tCtx.InputReader(), inputs.Literals, size)
}

if arrayJobSize > maxArrayJobSize {
Expand Down Expand Up @@ -246,18 +242,7 @@ func WriteToDiscovery(ctx context.Context, tCtx core.TaskExecutionContext, state
return state, externalResources, errors.Errorf(errors.MetadataAccessFailed, "Could not read inputs and therefore failed to determine array job size")
}

var literalCollection *idlCore.LiteralCollection
literals := make([][]*idlCore.Literal, 0)
discoveredInputNames := make([]string, 0)
for inputName, literal := range inputs.Literals {
if literalCollection = literal.GetCollection(); literalCollection != nil {
literals = append(literals, literalCollection.Literals)
discoveredInputNames = append(discoveredInputNames, inputName)
}
}

// build input readers
inputReaders = ConstructStaticInputReaders(tCtx.InputReader(), literals, discoveredInputNames)
inputReaders = ConstructStaticInputReaders(tCtx.InputReader(), inputs.Literals, arrayJobSize)
}

// output reader
Expand Down Expand Up @@ -476,16 +461,19 @@ func ConstructCatalogReaderWorkItems(ctx context.Context, taskReader core.TaskRe

// ConstructStaticInputReaders constructs input readers that comply with the io.InputReader interface but have their
// inputs already populated.
func ConstructStaticInputReaders(inputPaths io.InputFilePaths, inputs [][]*idlCore.Literal, inputNames []string) []io.InputReader {
inputReaders := make([]io.InputReader, 0, len(inputs))
if len(inputs) == 0 {
return inputReaders
}
func ConstructStaticInputReaders(inputPaths io.InputFilePaths, inputLiterals map[string]*idlCore.Literal, arrayJobSize int) []io.InputReader {
var literalCollection *idlCore.LiteralCollection

for i := 0; i < len(inputs[0]); i++ {
inputReaders := make([]io.InputReader, 0, arrayJobSize)
for i := 0; i < arrayJobSize; i++ {
literals := make(map[string]*idlCore.Literal)
for j := 0; j < len(inputNames); j++ {
literals[inputNames[j]] = inputs[j][i]
for inputName, inputLiteral := range inputLiterals {
if literalCollection = inputLiteral.GetCollection(); literalCollection != nil {
// if literal is a collection then we need to retrieve the specific literal for this subtask index
literals[inputName] = literalCollection.Literals[i]
} else {
literals[inputName] = inputLiteral
}
}

inputReaders = append(inputReaders, NewStaticInputReader(inputPaths, &idlCore.LiteralMap{Literals: literals}))
Expand Down
7 changes: 6 additions & 1 deletion flytepropeller/pkg/controller/nodes/array/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,12 @@ func (a *arrayNodeHandler) buildArrayNodeContext(ctx context.Context, nCtx inter
taskPhase := int(arrayNodeState.SubNodeTaskPhases.GetItem(subNodeIndex))

// need to initialize the inputReader every time to ensure TaskHandler can access for cache lookups / population
inputLiteralMap, err := constructLiteralMap(ctx, nCtx.InputReader(), subNodeIndex)
inputs, err := nCtx.InputReader().Get(ctx)
if err != nil {
return nil, nil, nil, nil, nil, nil, err
}

inputLiteralMap, err := constructLiteralMap(inputs, subNodeIndex)
if err != nil {
return nil, nil, nil, nil, nil, nil, err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package array

import (
"context"
"fmt"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io"
Expand All @@ -26,16 +27,16 @@ func newStaticInputReader(inputPaths io.InputFilePaths, input *core.LiteralMap)
}
}

func constructLiteralMap(ctx context.Context, inputReader io.InputReader, index int) (*core.LiteralMap, error) {
inputs, err := inputReader.Get(ctx)
if err != nil {
return nil, err
}

func constructLiteralMap(inputs *core.LiteralMap, index int) (*core.LiteralMap, error) {
literals := make(map[string]*core.Literal)
for name, literal := range inputs.Literals {
if literalCollection := literal.GetCollection(); literalCollection != nil {
if index >= len(literalCollection.Literals) {
return nil, fmt.Errorf("index %v out of bounds for literal collection %v", index, name)
}
literals[name] = literalCollection.Literals[index]
} else {
literals[name] = literal
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package array

import (
"testing"

"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
)

var (
literalOne = &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{
Primitive: &core.Primitive{
Value: &core.Primitive_Integer{
Integer: 1,
},
},
},
},
},
}
literalTwo = &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{
Primitive: &core.Primitive{
Value: &core.Primitive_Integer{
Integer: 2,
},
},
},
},
},
}
)

func TestConstructLiteralMap(t *testing.T) {
tests := []struct {
name string
inputLiteralMaps *core.LiteralMap
expectedLiteralMaps []*core.LiteralMap
}{
{
"SingleList",
&core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": &core.Literal{
Value: &core.Literal_Collection{
Collection: &core.LiteralCollection{
Literals: []*core.Literal{
literalOne,
literalTwo,
},
},
},
},
},
},
[]*core.LiteralMap{
&core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": literalOne,
},
},
&core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": literalTwo,
},
},
},
},
{
"MultiList",
&core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": &core.Literal{
Value: &core.Literal_Collection{
Collection: &core.LiteralCollection{
Literals: []*core.Literal{
literalOne,
literalTwo,
},
},
},
},
"bar": &core.Literal{
Value: &core.Literal_Collection{
Collection: &core.LiteralCollection{
Literals: []*core.Literal{
literalTwo,
literalOne,
},
},
},
},
},
},
[]*core.LiteralMap{
&core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": literalOne,
"bar": literalTwo,
},
},
&core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": literalTwo,
"bar": literalOne,
},
},
},
},
{
"Partial",
&core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": &core.Literal{
Value: &core.Literal_Collection{
Collection: &core.LiteralCollection{
Literals: []*core.Literal{
literalOne,
literalTwo,
},
},
},
},
"bar": literalTwo,
},
},
[]*core.LiteralMap{
&core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": literalOne,
"bar": literalTwo,
},
},
&core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": literalTwo,
"bar": literalTwo,
},
},
},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
for i := 0; i < len(test.expectedLiteralMaps); i++ {
outputLiteralMap, err := constructLiteralMap(test.inputLiteralMaps, i)
assert.NoError(t, err)
assert.True(t, proto.Equal(test.expectedLiteralMaps[i], outputLiteralMap))
}
})
}
}
Loading