Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
persisting k8s plugin state between evaluations (#540)
Browse files Browse the repository at this point in the history
* persisting k8s plugin state between evaluations

Signed-off-by: Daniel Rammer <daniel@union.ai>

* fixed unit tests and linter

Signed-off-by: Daniel Rammer <daniel@union.ai>

* added docs

Signed-off-by: Daniel Rammer <daniel@union.ai>

* updating flyteplugins dep

Signed-off-by: Daniel Rammer <daniel@union.ai>

* added unit tests

Signed-off-by: Daniel Rammer <daniel@union.ai>

---------

Signed-off-by: Daniel Rammer <daniel@union.ai>
  • Loading branch information
hamersaw committed Mar 30, 2023
1 parent b1e5482 commit 01218d2
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 15 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ require (
github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1
github.com/fatih/color v1.13.0
github.com/flyteorg/flyteidl v1.3.14
github.com/flyteorg/flyteplugins v1.0.43
github.com/flyteorg/flyteplugins v1.0.44
github.com/flyteorg/flytestdlib v1.0.15
github.com/ghodss/yaml v1.0.0
github.com/go-redis/redis v6.15.7+incompatible
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYF
github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/flyteorg/flyteidl v1.3.14 h1:o5M0g/r6pXTPu5PEurbYxbQmuOu3hqqsaI2M6uvK0N8=
github.com/flyteorg/flyteidl v1.3.14/go.mod h1:Pkt2skI1LiHs/2ZoekBnyPhuGOFMiuul6HHcKGZBsbM=
github.com/flyteorg/flyteplugins v1.0.43 h1:uI/Y88xqJKfvfuxfu0Sw9CNZ7iu3+HUwwRhxh558cbs=
github.com/flyteorg/flyteplugins v1.0.43/go.mod h1:ztsonku5fKwyxcIg1k69PTiBVjRI6d3nK5DnC+iwx08=
github.com/flyteorg/flyteplugins v1.0.44 h1:uKizng+i0vfXslyPBlrsfecInhvy71fTB4kRg7eiifE=
github.com/flyteorg/flyteplugins v1.0.44/go.mod h1:ztsonku5fKwyxcIg1k69PTiBVjRI6d3nK5DnC+iwx08=
github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0=
github.com/flyteorg/flytestdlib v1.0.15/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s=
github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk=
Expand Down
35 changes: 33 additions & 2 deletions pkg/controller/nodes/task/k8s/plugin_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package k8s

import (
"context"
"fmt"

pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
Expand All @@ -15,7 +16,8 @@ var _ k8s.PluginContext = &pluginContext{}
// pluginContext wraps a TaskExecutionContext for k8s plugins. It overrides the output writer and carries the
// k8s plugin state that PluginStateReader exposes to plugins.
type pluginContext struct {
pluginsCore.TaskExecutionContext
// Lazily creates a buffered outputWriter, overriding the input outputWriter.
ow *ioutils.BufferedOutputWriter
ow *ioutils.BufferedOutputWriter
// k8sPluginState is the pre-assigned state handed to the reader returned by PluginStateReader.
k8sPluginState *k8s.PluginState
}

// Provides an output sync of type io.OutputWriter
Expand All @@ -26,9 +28,38 @@ func (p *pluginContext) OutputWriter() io.OutputWriter {
return buf
}

func newPluginContext(tCtx pluginsCore.TaskExecutionContext) *pluginContext {
// pluginStateReader overrides the default PluginStateReader to return a pre-assigned PluginState. This allows us to
// encapsulate plugin state persistence in the existing k8s PluginManager and only expose the ability to read the
// previous Phase, PhaseVersion, and Reason for all k8s plugins.
type pluginStateReader struct {
// k8sPluginState is the state that Get copies into the caller-supplied *k8s.PluginState.
k8sPluginState *k8s.PluginState
}

// GetStateVersion reports the version of the state served by this reader, which is always 0.
func (p pluginStateReader) GetStateVersion() uint8 {
	const stateVersion uint8 = 0
	return stateVersion
}

// Get copies the pre-assigned k8s plugin state into t and returns state version 0.
//
// t must be a *k8s.PluginState; any other type yields an error. When the reader was built without any state
// (a nil k8sPluginState) t is left untouched, so the caller keeps the value it initialised instead of the call
// panicking on a nil dereference.
func (p pluginStateReader) Get(t interface{}) (stateVersion uint8, err error) {
	pointer, ok := t.(*k8s.PluginState)
	if !ok {
		return 0, fmt.Errorf("unexpected type when reading plugin state")
	}

	// The state pointer originates from a caller-supplied argument and may be nil; only copy when it is set.
	if p.k8sPluginState != nil {
		*pointer = *p.k8sPluginState
	}

	return 0, nil
}

// PluginStateReader replaces the embedded TaskExecutionContext behavior with a reader bound to this context's
// k8s plugin state.
func (p *pluginContext) PluginStateReader() pluginsCore.PluginStateReader {
	reader := pluginStateReader{k8sPluginState: p.k8sPluginState}
	return reader
}

// newPluginContext builds a pluginContext around tCtx that serves k8sPluginState through PluginStateReader.
func newPluginContext(tCtx pluginsCore.TaskExecutionContext, k8sPluginState *k8s.PluginState) *pluginContext {
	pCtx := new(pluginContext)
	pCtx.TaskExecutionContext = tCtx
	pCtx.k8sPluginState = k8sPluginState
	// The buffered output writer starts out unset.
	pCtx.ow = nil
	return pCtx
}
50 changes: 40 additions & 10 deletions pkg/controller/nodes/task/k8s/plugin_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ const (
)

// PluginState is the state the PluginManager reads and writes through the task's plugin state reader/writer.
type PluginState struct {
// Phase records whether the resource still has to be launched; Handle launches it while this is
// PluginPhaseNotStarted.
Phase PluginPhase
Phase PluginPhase
// K8sPluginState holds the Phase, PhaseVersion and Reason of the last transition, as written by Handle.
K8sPluginState k8s.PluginState
}

type PluginMetrics struct {
Expand Down Expand Up @@ -247,7 +248,7 @@ func (e *PluginManager) LaunchResource(ctx context.Context, tCtx pluginsCore.Tas
return pluginsCore.DoTransition(pluginsCore.PhaseInfoQueued(time.Now(), pluginsCore.DefaultPhaseVersion, "task submitted to K8s")), nil
}

func (e *PluginManager) CheckResourcePhase(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (pluginsCore.Transition, error) {
func (e *PluginManager) CheckResourcePhase(ctx context.Context, tCtx pluginsCore.TaskExecutionContext, k8sPluginState *k8s.PluginState) (pluginsCore.Transition, error) {

o, err := e.plugin.BuildIdentityResource(ctx, tCtx.TaskExecutionMetadata())
if err != nil {
Expand All @@ -274,7 +275,7 @@ func (e *PluginManager) CheckResourcePhase(ctx context.Context, tCtx pluginsCore
e.metrics.ResourceDeleted.Inc(ctx)
}

pCtx := newPluginContext(tCtx)
pCtx := newPluginContext(tCtx, k8sPluginState)
p, err := e.plugin.GetTaskPhase(ctx, pCtx, o)
if err != nil {
logger.Warnf(ctx, "failed to check status of resource in plugin [%s], with error: %s", e.GetID(), err.Error())
Expand Down Expand Up @@ -311,23 +312,52 @@ func (e *PluginManager) CheckResourcePhase(ctx context.Context, tCtx pluginsCore
}

// Handle reads the persisted PluginState, launches the resource when the plugin phase is PluginPhaseNotStarted or
// otherwise checks the resource phase, and writes a new PluginState whenever the plugin phase or the transition's
// Phase, Version or Reason differs from what was read.
//
// NOTE(review): this listing appears to interleave the removed and added lines of the commit diff — confirm the
// final statement order against the committed file.
func (e PluginManager) Handle(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (pluginsCore.Transition, error) {
// read phase state
ps := PluginState{}
if v, err := tCtx.PluginStateReader().Get(&ps); err != nil {
// a version mismatch is reported as a retryable failure transition rather than as an error
if v != pluginStateVersion {
return pluginsCore.DoTransition(pluginsCore.PhaseInfoRetryableFailure(errors.CorruptedPluginState, fmt.Sprintf("plugin state version mismatch expected [%d] got [%d]", pluginStateVersion, v), nil)), nil
}
return pluginsCore.UnknownTransition, errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state")
}

// evaluate plugin
var err error
var transition pluginsCore.Transition
pluginPhase := ps.Phase
if ps.Phase == PluginPhaseNotStarted {
t, err := e.LaunchResource(ctx, tCtx)
if err == nil && t.Info().Phase() == pluginsCore.PhaseQueued {
if err := tCtx.PluginStateWriter().Put(pluginStateVersion, &PluginState{Phase: PluginPhaseStarted}); err != nil {
return pluginsCore.UnknownTransition, err
}
transition, err = e.LaunchResource(ctx, tCtx)
if err == nil && transition.Info().Phase() == pluginsCore.PhaseQueued {
pluginPhase = PluginPhaseStarted
}
return t, err
} else {
// the previously persisted k8s plugin state is passed down so it can be served to the plugin
transition, err = e.CheckResourcePhase(ctx, tCtx, &ps.K8sPluginState)
}

if err != nil {
return transition, err
}
return e.CheckResourcePhase(ctx, tCtx)

// persist any changes in phase state
k8sPluginState := ps.K8sPluginState
if ps.Phase != pluginPhase || k8sPluginState.Phase != transition.Info().Phase() ||
k8sPluginState.PhaseVersion != transition.Info().Version() || k8sPluginState.Reason != transition.Info().Reason() {

newPluginState := PluginState{
Phase: pluginPhase,
K8sPluginState: k8s.PluginState{
Phase: transition.Info().Phase(),
PhaseVersion: transition.Info().Version(),
Reason: transition.Info().Reason(),
},
}

if err := tCtx.PluginStateWriter().Put(pluginStateVersion, &newPluginState); err != nil {
return pluginsCore.UnknownTransition, err
}
}

return transition, err
}

func (e PluginManager) Abort(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) error {
Expand Down
152 changes: 152 additions & 0 deletions pkg/controller/nodes/task/k8s/plugin_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"reflect"
"testing"

"k8s.io/client-go/kubernetes/scheme"
Expand Down Expand Up @@ -715,6 +716,157 @@ func TestPluginManager_Handle_CheckResourceStatus(t *testing.T) {
}
}

// TestPluginManager_Handle_PluginState verifies that Handle persists the k8s plugin state reported by the plugin:
// starting from a queued state, the stored PluginState must end up matching the reported phase, phase version
// and reason.
func TestPluginManager_Handle_PluginState(t *testing.T) {
	ctx := context.TODO()
	tm := getMockTaskExecutionMetadata()
	res := &v1.Pod{
		ObjectMeta: v12.ObjectMeta{
			Name:      tm.GetTaskExecutionID().GetGeneratedName(),
			Namespace: tm.GetNamespace(),
		},
	}

	// startedWith wraps a k8s plugin state in a PluginState whose resource has already been launched.
	startedWith := func(k8sState k8s.PluginState) PluginState {
		return PluginState{
			Phase:          PluginPhaseStarted,
			K8sPluginState: k8sState,
		}
	}
	// queuedInfoFor builds the queued PhaseInfo carrying the version and reason of the given state.
	queuedInfoFor := func(ps PluginState) pluginsCore.PhaseInfo {
		return pluginsCore.PhaseInfoQueuedWithTaskInfo(ps.K8sPluginState.PhaseVersion, ps.K8sPluginState.Reason, nil)
	}

	pluginStateQueued := startedWith(k8s.PluginState{
		Phase:        pluginsCore.PhaseQueued,
		PhaseVersion: 0,
		Reason:       "foo",
	})
	pluginStateQueuedVersion1 := startedWith(k8s.PluginState{
		Phase:        pluginsCore.PhaseQueued,
		PhaseVersion: 1,
		Reason:       "foo",
	})
	pluginStateQueuedReasonBar := startedWith(k8s.PluginState{
		Phase:        pluginsCore.PhaseQueued,
		PhaseVersion: 0,
		Reason:       "bar",
	})
	pluginStateRunning := startedWith(k8s.PluginState{
		Phase:        pluginsCore.PhaseRunning,
		PhaseVersion: 0,
		Reason:       "",
	})

	phaseInfoQueued := queuedInfoFor(pluginStateQueued)
	phaseInfoQueuedVersion1 := queuedInfoFor(pluginStateQueuedVersion1)
	phaseInfoQueuedReasonBar := queuedInfoFor(pluginStateQueuedReasonBar)
	phaseInfoRunning := pluginsCore.PhaseInfoRunning(0, nil)

	tests := []struct {
		name                string
		startPluginState    PluginState
		reportedPhaseInfo   pluginsCore.PhaseInfo
		expectedPluginState PluginState
	}{
		{
			name:                "NoChange",
			startPluginState:    pluginStateQueued,
			reportedPhaseInfo:   phaseInfoQueued,
			expectedPluginState: pluginStateQueued,
		},
		{
			name:                "K8sPhaseChange",
			startPluginState:    pluginStateQueued,
			reportedPhaseInfo:   phaseInfoRunning,
			expectedPluginState: pluginStateRunning,
		},
		{
			name:                "PhaseVersionChange",
			startPluginState:    pluginStateQueued,
			reportedPhaseInfo:   phaseInfoQueuedVersion1,
			expectedPluginState: pluginStateQueuedVersion1,
		},
		{
			name:                "ReasonChange",
			startPluginState:    pluginStateQueued,
			reportedPhaseInfo:   phaseInfoQueuedReasonBar,
			expectedPluginState: pluginStateQueuedReasonBar,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// mock TaskExecutionContext
			tCtx := &pluginsCoreMock.TaskExecutionContext{}
			tCtx.OnTaskExecutionMetadata().Return(getMockTaskExecutionMetadata())

			taskReader := &pluginsCoreMock.TaskReader{}
			taskReader.OnReadMatch(mock.Anything).Return(&core.TaskTemplate{}, nil)
			tCtx.OnTaskReader().Return(taskReader)

			// The mocked reader and writer share this pointer, so whatever Handle writes is what a later read
			// returns.
			current := &tc.startPluginState

			stateReader := &pluginsCoreMock.PluginStateReader{}
			readsState := func(i interface{}) bool {
				target, isState := i.(*PluginState)
				if !isState {
					return false
				}
				*target = *current
				return true
			}
			stateReader.OnGetMatch(mock.MatchedBy(readsState)).Return(uint8(0), nil)
			tCtx.OnPluginStateReader().Return(stateReader)

			stateWriter := &pluginsCoreMock.PluginStateWriter{}
			writesState := func(i interface{}) bool {
				written, isState := i.(*PluginState)
				if isState {
					*current = *written
				}
				return isState
			}
			stateWriter.OnPutMatch(mock.Anything, mock.MatchedBy(writesState)).Return(nil)
			tCtx.OnPluginStateWriter().Return(stateWriter)
			tCtx.OnOutputWriter().Return(&dummyOutputWriter{})

			fc := extendedFakeClient{Client: fake.NewFakeClient(res)}

			// plugin mock that always reports the phase info under test
			plugin := &pluginsk8sMock.Plugin{}
			plugin.OnGetProperties().Return(k8s.PluginProperties{})
			plugin.On("BuildIdentityResource", mock.Anything, tCtx.TaskExecutionMetadata()).Return(&v1.Pod{}, nil)
			plugin.On("GetTaskPhase", mock.Anything, mock.Anything, mock.Anything).Return(tc.reportedPhaseInfo, nil)

			// create new PluginManager
			entry := k8s.PluginEntry{
				ID:              "x",
				ResourceToWatch: &v1.Pod{},
				Plugin:          plugin,
			}
			pluginManager, err := NewPluginManager(ctx, dummySetupContext(fc), entry, NewResourceMonitorIndex())
			assert.NoError(t, err)

			// handle plugin
			_, err = pluginManager.Handle(ctx, tCtx)
			assert.NoError(t, err)

			// read the state back through the mocked reader and compare it with the expectation
			stored := PluginState{}
			_, err = tCtx.PluginStateReader().Get(&stored)
			assert.NoError(t, err)

			matches := reflect.DeepEqual(stored, tc.expectedPluginState)
			assert.True(t, matches)
		})
	}
}

func TestPluginManager_CustomKubeClient(t *testing.T) {
ctx := context.TODO()
tctx := getMockTaskContext(PluginPhaseNotStarted, PluginPhaseStarted)
Expand Down

0 comments on commit 01218d2

Please sign in to comment.