-
Notifications
You must be signed in to change notification settings - Fork 4
/
worker.go
286 lines (227 loc) · 7.31 KB
/
worker.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
// Package worker manages the execution and queue of tasks.
package worker
import (
"context"
"errors"
"fmt"
"path/filepath"
"sync"
"time"
log "github.com/canonical/ubuntu-pro-for-wsl/common/grpc/logstreamer"
"github.com/canonical/ubuntu-pro-for-wsl/windows-agent/internal/distros/task"
"github.com/ubuntu/decorate"
)
// distro is the minimal view of a distro that this worker needs: identity,
// awake-lock control, and validity tracking.
type distro interface {
	// Name returns the distro's name, used in logs and to derive the task storage file.
	Name() string

	// LockAwake keeps the distro running. Every successful call in this file is paired
	// with a deferred ReleaseAwake, and calls nest (processSingleTask holds the lock
	// while waitForActiveConnection takes it again) — presumably the implementation
	// is reference-counted; TODO confirm against the distro package.
	LockAwake() error
	ReleaseAwake() error

	// IsValid reports whether the distro is still usable; Invalidate marks it as not.
	IsValid() bool
	Invalidate(context.Context)
}
// Connection encapsulates the logic behind sending and receiving messages
// with the WSL-Pro-Service.
type Connection interface {
	// SendProAttachment forwards the Ubuntu Pro token to the service.
	SendProAttachment(proToken string) error
	// SendLandscapeConfig forwards the Landscape client config and host agent UID.
	SendLandscapeConfig(lpeConfig, hostagentUID string) error
	// Close tears down the connection. It is called by SetConnection when replacing it.
	Close()
}
// Worker contains all the logic around task queueing and execution for one particular distro.
type Worker struct {
	distro  distro
	manager *taskManager // persistent task queue, backed by the <distro>.tasks file

	cancel     context.CancelFunc // stops the processTasks goroutine (set by start)
	processing chan struct{}      // closed when processTasks returns; Stop blocks on it

	conn   Connection   // active GRPC connection; nil means the worker is inactive
	connMu sync.RWMutex // guards conn
}
// New creates a new worker and starts it. Call Stop when you're done to avoid leaking the task execution goroutine.
func New(ctx context.Context, d distro, storageDir string) (w *Worker, err error) {
	defer decorate.OnError(&err, "distro %q: could not create worker", d.Name())

	// Each distro persists its queue in its own <name>.tasks file under storageDir.
	manager, err := newTaskManager(filepath.Join(storageDir, d.Name()+".tasks"))
	if err != nil {
		return nil, err
	}

	worker := &Worker{
		distro:  d,
		manager: manager,
	}
	worker.start(ctx)

	return worker, nil
}
// IsActive returns true when the worker is running, and there exists an active
// connection to its GRPC service.
func (w *Worker) IsActive() bool {
	// Connection performs the mutex-protected read for us.
	return w.Connection() != nil
}
// Connection returns the client to the WSL task service.
// Connection returns nil when no connection is set up.
func (w *Worker) Connection() Connection {
	w.connMu.RLock()
	conn := w.conn
	w.connMu.RUnlock()
	return conn
}
// SetConnection sets the connection associated with the distro, closing the
// previous one if there was any. Passing nil removes the connection, which
// deactivates the worker (see IsActive).
func (w *Worker) SetConnection(conn Connection) {
	w.connMu.Lock()
	defer w.connMu.Unlock()

	// Avoid leaking the old connection when it is replaced.
	if w.conn != nil {
		w.conn.Close()
	}
	w.conn = conn
}
// start launches the main task-processing goroutine and records how to stop it.
func (w *Worker) start(ctx context.Context) {
	log.Debugf(ctx, "Distro %q: starting task processing", w.distro.Name())

	ctx, cancel := context.WithCancel(ctx)
	w.cancel = cancel
	w.processing = make(chan struct{})

	go w.processTasks(ctx)
}
// Stop stops the main task processing goroutine and waits for it to be done.
// It also drops (and closes) any active connection via SetConnection(nil).
func (w *Worker) Stop(ctx context.Context) {
	log.Debugf(ctx, "Distro %q: stopping task processing", w.distro.Name())
	w.cancel()
	// processing is closed when processTasks returns.
	<-w.processing

	w.SetConnection(nil)
}
// SubmitTasks enqueues one or more task on our current worker list. The task will wake up
// the distro and be performed as soon as it reaches the beginning of the queue.
//
// It will return an error if the distro has been cleaned up or the task queue is full.
func (w *Worker) SubmitTasks(tasks ...task.Task) (err error) {
	defer decorate.OnError(&err, "distro %q: tasks %q: could not submit", w.distro.Name(), tasks)

	if len(tasks) > 0 {
		log.Infof(context.TODO(), "Distro %q: Submitting tasks %q to queue", w.distro.Name(), tasks)
		return w.manager.Submit(false, tasks...)
	}

	// Nothing to enqueue.
	return nil
}
// SubmitDeferredTasks takes one or more tasks into our current worker list.
//
// The task(s) won't wake up the distro, instead wait until it is awake. This does
// NOT necessarily mean it'll run after non-deferred tasks.
//
// It will return an error if the distro has been cleaned up.
func (w *Worker) SubmitDeferredTasks(tasks ...task.Task) (err error) {
	defer decorate.OnError(&err, "distro %q: tasks %q: could not submit", w.distro.Name(), tasks)

	if len(tasks) > 0 {
		log.Infof(context.TODO(), "Distro %q: Submitting tasks %q to queue", w.distro.Name(), tasks)
		return w.manager.Submit(true, tasks...)
	}

	// Nothing to enqueue.
	return nil
}
// EnqueueDeferredTasks takes all deferred tasks and promotes them
// to regular tasks. Thin delegate to the task manager.
func (w *Worker) EnqueueDeferredTasks() {
	w.manager.EnqueueDeferredTasks()
}
// processTasks is the main loop for the distro: it pulls tasks off the queue and
// executes them one at a time, acquiring and releasing the distro's awake lock as
// needed. It returns when the context is cancelled (NextTask reports !ok), closing
// w.processing so Stop can observe completion.
func (w *Worker) processTasks(ctx context.Context) {
	defer close(w.processing)

	for {
		t, ok := w.manager.NextTask(ctx)
		if !ok {
			// Queue shut down / context cancelled: exit the loop.
			return
		}

		resultErr := w.processSingleTask(ctx, t)

		// An unreachable distro is terminal: invalidate it and skip TaskDone so the
		// task is not marked as handled.
		var target unreachableDistroError
		if errors.As(resultErr, &target) {
			log.Errorf(ctx, "Distro %q: task %q: distro not reachable: %v", w.distro.Name(), t, target.sourceErr)
			w.distro.Invalidate(ctx)
			continue
		}

		// TaskDone decides requeue-vs-drop based on resultErr; its own error is only logged.
		err := w.manager.TaskDone(ctx, t, resultErr)
		if err != nil {
			log.Errorf(ctx, "Distro %q: %v", w.distro.Name(), err)
		}
	}
}
// unreachableDistroError signals that the distro cannot be contacted and should
// be invalidated (see processTasks). It wraps the underlying cause.
type unreachableDistroError struct {
	sourceErr error
}

// newUnreachableDistroErr wraps err in an unreachableDistroError.
// It returns nil when err is nil, so it can be applied to any error result.
func newUnreachableDistroErr(err error) error {
	if err == nil {
		return nil
	}
	return unreachableDistroError{
		sourceErr: err,
	}
}

// Error implements the error interface.
func (err unreachableDistroError) Error() string {
	return fmt.Sprintf("distro cannot be reached: %v", err.sourceErr)
}

// Unwrap exposes the underlying cause so errors.Is/errors.As can traverse it.
func (err unreachableDistroError) Unwrap() error {
	return err.sourceErr
}
// processSingleTask wakes the distro, waits for an active GRPC connection, and
// executes task t against it. It returns an unreachableDistroError when the
// distro cannot be woken or contacted; other errors come from the task itself.
func (w *Worker) processSingleTask(ctx context.Context, t task.Task) error {
	log.Debugf(ctx, "Distro %q: starting task %q", w.distro.Name(), t)

	// A distro already marked invalid is not worth waking up.
	if !w.distro.IsValid() {
		return newUnreachableDistroErr(errors.New("distro marked as invalid"))
	}

	// Task-scoped context: cancelled when this task finishes.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Hold the awake lock for the whole task execution. waitForActiveConnection
	// takes the lock again internally — presumably reference-counted; TODO confirm.
	if err := w.distro.LockAwake(); err != nil {
		return newUnreachableDistroErr(err)
	}
	//nolint:errcheck // Nothing we can do about it
	defer w.distro.ReleaseAwake()

	log.Debugf(ctx, "Distro %q: distro is running.", w.distro.Name())

	client, err := w.waitForActiveConnection(ctx)
	if err != nil {
		return fmt.Errorf("task %v: could not start task: %w", t, err)
	}

	if err := t.Execute(ctx, client); err != nil {
		return fmt.Errorf("distro %q: task %q failed: %w", w.distro.Name(), t, err)
	}

	log.Debugf(ctx, "Distro %q: task %q: task completed successfully", w.distro.Name(), t)
	return nil
}
// waitForActiveConnection obtains an active connection to the distro's GRPC
// service, retrying up to 5 times. Each attempt takes its own awake lock so the
// distro is restarted if it stopped between attempts, then waits (bounded, see
// waitForClient) for a connection to appear. Returns the last error on failure.
func (w *Worker) waitForActiveConnection(ctx context.Context) (conn Connection, err error) {
	log.Debugf(ctx, "Distro %q: ensuring active connection.", w.distro.Name())

	for range 5 {
		// Closure scopes the deferred ReleaseAwake to a single attempt.
		conn, err = func() (Connection, error) {
			// Potentially restart distro if it was stopped for some reason
			if err := w.distro.LockAwake(); err != nil {
				return nil, newUnreachableDistroErr(err)
			}
			//nolint:errcheck // Nothing we can do about it
			defer w.distro.ReleaseAwake()

			// Connect to GRPC client
			client, err := w.waitForClient(ctx)
			if err != nil {
				return nil, err
			}

			log.Debugf(ctx, "Distro %q: connection is active.", w.distro.Name())
			return client, nil
		}()
		if err == nil {
			break
		}
	}

	return conn, err
}
// waitForClient waits for a valid GRPC client to connect to. It will retry for a while before
// erroring out.
//
// It polls the worker's connection (immediately, then once per second) until one
// is available, the 30-second deadline expires, or ctx is cancelled.
func (w *Worker) waitForClient(ctx context.Context) (Connection, error) {
	timedOutCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	// First poll is immediate; subsequent polls wait a second.
	tickRate := 0 * time.Second
	for {
		select {
		case <-ctx.Done():
			// Context cancelled means agent teardown.
			return nil, fmt.Errorf("stopped waiting for client: %v", timedOutCtx.Err())
		case <-timedOutCtx.Done():
			// timedOutCtx derives from ctx, so this channel also closes on agent
			// teardown, and the select may land here instead of the case above.
			// Distinguish parent cancellation from a genuine timeout so we never
			// report an unreachable distro (which gets invalidated) on shutdown.
			if ctx.Err() != nil {
				return nil, fmt.Errorf("stopped waiting for client: %v", ctx.Err())
			}
			// Timeout means the distro is not reachable.
			return nil, newUnreachableDistroErr(errors.New("timed out waiting for client"))
		case <-time.After(tickRate):
			conn := w.Connection()
			if conn == nil {
				tickRate = time.Second
				continue
			}
			return conn, nil
		}
	}
}