This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 60
/
evaluator.go
139 lines (126 loc) · 4.86 KB
/
evaluator.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package branch
import (
"context"
"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
"github.com/lyft/flytepropeller/pkg/controller/nodes/errors"
"github.com/lyft/flytepropeller/pkg/controller/nodes/handler"
"github.com/lyft/flytestdlib/logger"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
regErrors "github.com/pkg/errors"
)
// EvaluateComparison evaluates a single comparison expression of a branch condition.
//
// Each side of the comparison is either an inline primitive or a variable name. Variables are
// resolved against nodeInputs.Literals. An error is returned if nodeInputs is nil or the
// variable has no value there. The resolved operands are then dispatched by kind:
// EvaluateLiterals (variable vs variable), Evaluate2 (variable vs primitive),
// Evaluate1 (primitive vs variable) or Evaluate (primitive vs primitive).
func EvaluateComparison(expr *core.ComparisonExpression, nodeInputs *handler.Data) (bool, error) {
	var lValue *core.Literal
	var rValue *core.Literal
	var lPrim *core.Primitive
	var rPrim *core.Primitive

	left := expr.GetLeftValue()
	if left.GetPrimitive() == nil {
		if nodeInputs == nil {
			return false, regErrors.Errorf("Failed to find Value for Variable [%v]", left.GetVar())
		}
		lValue = nodeInputs.Literals[left.GetVar()]
		if lValue == nil {
			return false, regErrors.Errorf("Failed to find Value for Variable [%v]", left.GetVar())
		}
	} else {
		lPrim = left.GetPrimitive()
	}

	right := expr.GetRightValue()
	if right.GetPrimitive() == nil {
		if nodeInputs == nil {
			// Report the right-hand variable: this is the operand that could not be resolved.
			return false, regErrors.Errorf("Failed to find Value for Variable [%v]", right.GetVar())
		}
		rValue = nodeInputs.Literals[right.GetVar()]
		if rValue == nil {
			return false, regErrors.Errorf("Failed to find Value for Variable [%v]", right.GetVar())
		}
	} else {
		rPrim = right.GetPrimitive()
	}

	// Exactly one of lValue/lPrim and one of rValue/rPrim is set at this point.
	if lValue != nil && rValue != nil {
		return EvaluateLiterals(lValue, rValue, expr.GetOperator())
	}
	if lValue != nil && rPrim != nil {
		return Evaluate2(lValue, rPrim, expr.GetOperator())
	}
	if lPrim != nil && rValue != nil {
		return Evaluate1(lPrim, rValue, expr.GetOperator())
	}
	return Evaluate(lPrim, rPrim, expr.GetOperator())
}
// EvaluateBooleanExpression evaluates a branch condition against the node inputs.
//
// A comparison is delegated to EvaluateComparison. Otherwise the expression must be a
// conjunction, and an error is returned if it is neither. For a conjunction, both
// sub-expressions are evaluated recursively, left first and then right. There is no
// short-circuiting, so an error on either side is always surfaced. The results are combined
// with OR when the operator is ConjunctionExpression_OR, and with AND for any other operator.
func EvaluateBooleanExpression(expr *core.BooleanExpression, nodeInputs *handler.Data) (bool, error) {
	if comparison := expr.GetComparison(); comparison != nil {
		return EvaluateComparison(comparison, nodeInputs)
	}

	conj := expr.GetConjunction()
	if conj == nil {
		return false, regErrors.Errorf("No Comparison or Conjunction found in Branch node expression.")
	}

	leftResult, leftErr := EvaluateBooleanExpression(conj.GetLeftExpression(), nodeInputs)
	if leftErr != nil {
		return false, leftErr
	}
	rightResult, rightErr := EvaluateBooleanExpression(conj.GetRightExpression(), nodeInputs)
	if rightErr != nil {
		return false, rightErr
	}

	switch conj.GetOperator() {
	case core.ConjunctionExpression_OR:
		return leftResult || rightResult, nil
	default:
		return leftResult && rightResult, nil
	}
}
// EvaluateIfBlock evaluates the condition of a single if / else-if block.
//
// When the condition holds, the block's then-node is returned as the selected node and
// skippedNodeIds is returned unchanged. When it does not hold, no node is selected and the
// then-node is appended to skippedNodeIds. If the condition cannot be evaluated, the error is
// returned together with the unchanged skippedNodeIds.
func EvaluateIfBlock(block v1alpha1.ExecutableIfBlock, nodeInputs *handler.Data, skippedNodeIds []*v1alpha1.NodeID) (*v1alpha1.NodeID, []*v1alpha1.NodeID, error) {
	taken, evalErr := EvaluateBooleanExpression(block.GetCondition(), nodeInputs)
	if evalErr != nil {
		return nil, skippedNodeIds, evalErr
	}
	if !taken {
		// Condition is false: nothing is selected and this block's then-node is recorded as skipped.
		return nil, append(skippedNodeIds, block.GetThenNode()), nil
	}
	return block.GetThenNode(), skippedNodeIds, nil
}
// DecideBranch evaluates the conditions of a branch node and returns the ID of the node to run.
//
// The if-block is evaluated first, then each else-if block in order. The first condition that
// evaluates to true selects its then-node, and later else-if conditions are not evaluated. If no
// condition holds, the else-node (when present) is selected. Every then-node that was not
// selected is marked NodePhaseSkipped in the workflow's execution status. The else-node is
// also marked skipped when a condition matched.
//
// Errors:
//   - A condition fails to evaluate. Nothing has been marked skipped yet.
//   - A node to be skipped is not found in the workflow (DownstreamNodeNotFoundError).
//   - No branch is selected. The configured else-fail message is returned as UserProvidedError;
//     without one, NoBranchTakenError is returned.
//
// Note that skipped nodes are marked before the no-branch errors are returned.
func DecideBranch(ctx context.Context, w v1alpha1.BaseWorkflowWithStatus, nodeID v1alpha1.NodeID, node v1alpha1.ExecutableBranchNode, nodeInputs *handler.Data) (*v1alpha1.NodeID, error) {
	selectedNodeID, skippedNodeIDs, err := EvaluateIfBlock(node.GetIf(), nodeInputs, nil)
	if err != nil {
		return nil, err
	}

	for _, block := range node.GetElseIf() {
		if selectedNodeID != nil {
			// A previous condition already matched; remaining blocks are skipped unevaluated.
			skippedNodeIDs = append(skippedNodeIDs, block.GetThenNode())
			continue
		}
		selectedNodeID, skippedNodeIDs, err = EvaluateIfBlock(block, nodeInputs, skippedNodeIDs)
		if err != nil {
			return nil, err
		}
	}

	if elseNode := node.GetElse(); elseNode != nil {
		if selectedNodeID == nil {
			selectedNodeID = elseNode
		} else {
			skippedNodeIDs = append(skippedNodeIDs, elseNode)
		}
	}

	for _, skippedPtr := range skippedNodeIDs {
		skippedNodeID := *skippedPtr
		n, ok := w.GetNode(skippedNodeID)
		if !ok {
			return nil, errors.Errorf(errors.DownstreamNodeNotFoundError, nodeID, "Downstream node [%v] not found", skippedNodeID)
		}
		nStatus := w.GetNodeExecutionStatus(n.GetID())
		logger.Infof(ctx, "Branch Setting Node[%v] status to Skipped!", skippedNodeID)
		nStatus.UpdatePhase(v1alpha1.NodePhaseSkipped, v1.Now(), "Branch evaluated to false")
	}

	if selectedNodeID == nil {
		if elseFail := node.GetElseFail(); elseFail != nil {
			// The message is user-supplied text, so pass it as an argument rather than as the
			// format string. Otherwise a literal '%' in it would be read as a formatting verb.
			return nil, errors.Errorf(errors.UserProvidedError, nodeID, "%s", elseFail.Message)
		}
		return nil, errors.Errorf(errors.NoBranchTakenError, nodeID, "No branch satisfied")
	}

	logger.Infof(ctx, "Branch Node[%v] selected!", *selectedNodeID)
	return selectedNodeID, nil
}