/
patch_node_handler.go
161 lines (143 loc) · 4.03 KB
/
patch_node_handler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
package actions
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"strconv"
"github.com/sirupsen/logrus"
v1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/client-go/kubernetes"
"github.com/castai/cluster-controller/castai"
)
// newPatchNodeHandler builds the ActionHandler that applies ActionPatchNode
// requests, using log for logging and clientset for Kubernetes API access.
func newPatchNodeHandler(log logrus.FieldLogger, clientset kubernetes.Interface) ActionHandler {
	h := &patchNodeHandler{}
	h.log = log
	h.clientset = clientset
	return h
}
// patchNodeHandler handles ActionPatchNode cluster actions by patching the
// target Node's labels, annotations, taints, unschedulable flag and, when
// requested, its status capacity.
type patchNodeHandler struct {
	// log is the base logger; Handle derives a per-action logger from it.
	log logrus.FieldLogger
	// clientset is the Kubernetes client used to fetch and patch nodes.
	clientset kubernetes.Interface
}
// Handle applies an ActionPatchNode request to the node it names.
//
// It first rejects requests whose labels, annotations or taints contain an
// empty key, then fetches the node by name. A node that no longer exists is
// treated as success and the patch is skipped. Labels, annotations, taints
// and the unschedulable flag are applied together in one node patch, where a
// key prefixed with '-' requests removal (see patchNodeMapField and
// patchTaints). If the request carries Capacity, a separate status patch is
// sent afterwards.
//
// It returns an error when the action payload has an unexpected type, when
// validation fails, or when fetching or patching the node fails.
func (h *patchNodeHandler) Handle(ctx context.Context, action *castai.ClusterAction) error {
	req, ok := action.Data().(*castai.ActionPatchNode)
	if !ok {
		return fmt.Errorf("unexpected type %T for patch node handler", action.Data())
	}

	// Reject empty keys up front: the patch helpers inspect the first byte of
	// each key to detect the '-' removal prefix, and an empty key is not a
	// meaningful label/annotation/taint key anyway.
	for k := range req.Labels {
		if k == "" {
			return errors.New("labels contain entry with empty key")
		}
	}
	for k := range req.Annotations {
		if k == "" {
			return errors.New("annotations contain entry with empty key")
		}
	}
	for _, t := range req.Taints {
		if t.Key == "" {
			return errors.New("taint contain entry with empty key")
		}
	}

	log := h.log.WithFields(logrus.Fields{
		"node_name":      req.NodeName,
		"node_id":        req.NodeID,
		"action":         reflect.TypeOf(req).String(),
		actionIDLogField: action.ID,
	})

	// Pass the request-scoped logger so any helper log lines carry the
	// node/action fields above.
	node, err := getNodeForPatching(ctx, log, h.clientset, req.NodeName)
	if err != nil {
		if apierrors.IsNotFound(err) {
			log.WithError(err).Info("node not found, skipping patch")
			return nil
		}
		return err
	}

	// Human-readable form of the optional unschedulable flag for the log line.
	unschedulable := "<nil>"
	if req.Unschedulable != nil {
		unschedulable = strconv.FormatBool(*req.Unschedulable)
	}

	if req.Unschedulable == nil && len(req.Labels) == 0 && len(req.Taints) == 0 && len(req.Annotations) == 0 {
		log.Info("no patch for node spec or labels")
	} else {
		log.WithFields(map[string]interface{}{
			"labels":      req.Labels,
			"taints":      req.Taints,
			"annotations": req.Annotations,
			"capacity":    req.Capacity,
		}).Infof("patching node, labels=%v, taints=%v, annotations=%v, unschedulable=%v", req.Labels, req.Taints, req.Annotations, unschedulable)

		err = patchNode(ctx, log, h.clientset, node, func(n *v1.Node) {
			n.Labels = patchNodeMapField(n.Labels, req.Labels)
			n.Annotations = patchNodeMapField(n.Annotations, req.Annotations)
			n.Spec.Taints = patchTaints(n.Spec.Taints, req.Taints)
			n.Spec.Unschedulable = patchUnschedulable(n.Spec.Unschedulable, req.Unschedulable)
		})
		if err != nil {
			return err
		}
	}

	if len(req.Capacity) > 0 {
		log.WithField("capacity", req.Capacity).Info("patching node status")
		// Capacity lives in the node status, which needs its own patch body.
		patch, err := json.Marshal(map[string]interface{}{
			"status": map[string]interface{}{
				"capacity": req.Capacity,
			},
		})
		if err != nil {
			return fmt.Errorf("marshal patch for status: %w", err)
		}
		return patchNodeStatus(ctx, log, h.clientset, node.Name, patch)
	}

	return nil
}
// patchNodeMapField applies patch to values (a node's labels or annotations)
// and returns the resulting map.
//
// Patch semantics:
//   - a key prefixed with '-' deletes the key without that prefix from values
//     (the patch value is ignored);
//   - any other key is set to its patch value;
//   - an empty key is skipped, so this helper cannot panic even if a caller
//     forgot to validate keys.
//
// A non-nil values map is modified in place and returned. A nil values map is
// replaced by a fresh map, so the result is never nil.
func patchNodeMapField(values map[string]string, patch map[string]string) map[string]string {
	if values == nil {
		values = make(map[string]string, len(patch))
	}
	for k, v := range patch {
		switch {
		case k == "":
			// Nothing to apply; indexing k[0] below would panic.
		case k[0] == '-':
			delete(values, k[1:])
		default:
			values[k] = v
		}
	}
	return values
}
// patchTaints applies the requested taint changes to taints and returns the
// resulting slice.
//
// For each patch entry:
//   - a key prefixed with '-' removes every existing taint that matches the
//     key without the prefix and the entry's effect (via deleteTaint);
//   - any other entry is appended unless findTaint already reports a match.
//     NOTE: an existing match is kept as-is, so its value is not updated;
//   - an entry with an empty key is skipped, so this helper cannot panic
//     even if a caller forgot to validate keys.
func patchTaints(taints []v1.Taint, patch []castai.NodeTaint) []v1.Taint {
	for _, v := range patch {
		if v.Key == "" {
			// Nothing to apply; indexing v.Key[0] below would panic.
			continue
		}
		taint := &v1.Taint{Key: v.Key, Value: v.Value, Effect: v1.TaintEffect(v.Effect)}
		if v.Key[0] == '-' {
			taint.Key = taint.Key[1:]
			taints = deleteTaint(taints, taint)
		} else if _, found := findTaint(taints, taint); !found {
			taints = append(taints, *taint)
		}
	}
	return taints
}
// patchUnschedulable returns the desired Spec.Unschedulable value. It is the
// requested value when patch is set; otherwise the current value is kept.
func patchUnschedulable(unschedulable bool, patch *bool) bool {
	if patch == nil {
		return unschedulable
	}
	return *patch
}
// findTaint reports the first element of taints for which
// v1.Taint.MatchTaint(t) is true (in k8s.io/api that compares key and
// effect). It returns that taint and true, or the zero Taint and false when
// nothing matches.
func findTaint(taints []v1.Taint, t *v1.Taint) (v1.Taint, bool) {
	for i := range taints {
		if taints[i].MatchTaint(t) {
			return taints[i], true
		}
	}
	var none v1.Taint
	return none, false
}
// deleteTaint returns a new slice holding every element of taints for which
// v1.Taint.MatchTaint(t) is false. The input slice is not modified, and the
// result is nil when no taints remain.
func deleteTaint(taints []v1.Taint, t *v1.Taint) []v1.Taint {
	var kept []v1.Taint
	for i := range taints {
		if taints[i].MatchTaint(t) {
			continue
		}
		kept = append(kept, taints[i])
	}
	return kept
}