fix: Allow provisionerd to cleanup acquired job (#159)
If a job was acquired from the database and provisionerd was then
killed, the job would be left idle while still marked as
in-progress.
kylecarbs committed Feb 4, 2022
1 parent 94f71fe commit 2eab1b5
Showing 2 changed files with 54 additions and 65 deletions.
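Before the diff, a rough outline of the shutdown path this commit establishes: job acquisition and cancellation are serialized behind a jobMutex, acquireJob refuses new work once the daemon is closed, and closing the daemon fails the acquired job (via the printf-style cancelActiveJobf) before tearing down the connection context. The code below is a minimal, self-contained approximation of that pattern, not the real provisionerd API; the daemon type, its fields, and main are simplified stand-ins.

package main

import (
	"context"
	"fmt"
	"sync"
)

// daemon is a trimmed stand-in for provisionerDaemon: one in-flight job,
// a cancel function for it, and a closed channel for the daemon itself.
type daemon struct {
	closed      chan struct{}
	closeCancel context.CancelFunc

	jobMutex   sync.Mutex // locked when acquiring or canceling a job, as in the diff
	jobID      string
	jobRunning chan struct{}
	jobCancel  context.CancelFunc
}

func (d *daemon) isClosed() bool {
	select {
	case <-d.closed:
		return true
	default:
		return false
	}
}

func (d *daemon) isRunningJob() bool {
	select {
	case <-d.jobRunning:
		return false
	default:
		return true
	}
}

// acquireJob refuses new work once the daemon is closed; the check happens
// under jobMutex so it cannot race with a concurrent close.
func (d *daemon) acquireJob(ctx context.Context, id string) {
	d.jobMutex.Lock()
	defer d.jobMutex.Unlock()
	if d.isClosed() || d.isRunningJob() {
		return
	}
	ctx, d.jobCancel = context.WithCancel(ctx)
	d.jobRunning = make(chan struct{})
	d.jobID = id
	running := d.jobRunning
	go func() {
		defer close(running)
		<-ctx.Done() // stand-in for actually running the provisioner job
	}()
}

// cancelActiveJobf fails the running job with a printf-style reason and
// waits for its goroutine to finish cleaning up.
func (d *daemon) cancelActiveJobf(format string, args ...interface{}) {
	d.jobMutex.Lock()
	defer d.jobMutex.Unlock()
	if !d.isRunningJob() {
		return
	}
	fmt.Printf("failed job %s: %s\n", d.jobID, fmt.Sprintf(format, args...))
	d.jobCancel()
	<-d.jobRunning
}

// closeWithError marks the daemon closed, fails the acquired job, and only
// then cancels the connection context, so no job is left "in progress".
// (The real code additionally serializes closes behind a closeMutex.)
func (d *daemon) closeWithError(err error) error {
	if d.isClosed() {
		return err
	}
	close(d.closed)
	errMsg := "provisioner daemon was shutdown gracefully"
	if err != nil {
		errMsg = err.Error()
	}
	d.cancelActiveJobf(errMsg)
	d.closeCancel()
	return err
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	d := &daemon{
		closed:      make(chan struct{}),
		closeCancel: cancel,
		jobRunning:  make(chan struct{}),
	}
	close(d.jobRunning) // nothing is running yet
	d.acquireJob(ctx, "job-123")
	_ = d.closeWithError(nil) // prints: failed job job-123: provisioner daemon was shutdown gracefully
}

Running the sketch prints a failure line for the acquired job instead of leaving it stranded, which is the behavior the commit message is after.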
1 change: 1 addition & 0 deletions .vscode/settings.json
@@ -34,6 +34,7 @@
"goleak",
"hashicorp",
"httpmw",
"Jobf",
"moby",
"nhooyr",
"nolint",
118 changes: 53 additions & 65 deletions provisionerd/provisionerd.go
@@ -51,9 +51,8 @@ func New(clientDialer Dialer, opts *Options) io.Closer {
clientDialer: clientDialer,
opts: opts,

closeContext: ctx,
closeCancel: ctxCancel,
closed: make(chan struct{}),
closeCancel: ctxCancel,
closed: make(chan struct{}),

jobRunning: make(chan struct{}),
}
@@ -71,23 +70,21 @@ type provisionerDaemon struct {
client proto.DRPCProvisionerDaemonClient
updateStream proto.DRPCProvisionerDaemon_UpdateJobClient

closeContext context.Context
closeCancel context.CancelFunc
closed chan struct{}
closeMutex sync.Mutex
closeError error
// Locked when closing the daemon.
closeMutex sync.Mutex
closeCancel context.CancelFunc
closed chan struct{}
closeError error

jobID string
// Locked when acquiring or canceling a job.
jobMutex sync.Mutex
jobID string
jobRunning chan struct{}
jobCancel context.CancelFunc
}

// Connect establishes a connection to coderd.
func (p *provisionerDaemon) connect(ctx context.Context) {
p.jobMutex.Lock()
defer p.jobMutex.Unlock()

var err error
// An exponential back-off occurs when the connection is failing to dial.
// This is to prevent server spam in case of a coderd outage.
@@ -102,6 +99,9 @@ func (p *provisionerDaemon) connect(ctx context.Context) {
}
p.updateStream, err = p.client.UpdateJob(ctx)
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
p.opts.Logger.Warn(context.Background(), "create update job stream", slog.Error(err))
continue
}
@@ -126,12 +126,6 @@ func (p *provisionerDaemon) connect(ctx context.Context) {
// has been interrupted. This works well, because logs need
// to buffer if a job is running in the background.
p.opts.Logger.Debug(context.Background(), "update stream ended", slog.Error(p.updateStream.Context().Err()))
// Make sure we're not closing here!
p.closeMutex.Lock()
defer p.closeMutex.Unlock()
if p.isClosed() {
return
}
p.connect(ctx)
}
}()
@@ -168,6 +162,9 @@ func (p *provisionerDaemon) isRunningJob() bool {
func (p *provisionerDaemon) acquireJob(ctx context.Context) {
p.jobMutex.Lock()
defer p.jobMutex.Unlock()
if p.isClosed() {
return
}
if p.isRunningJob() {
p.opts.Logger.Debug(context.Background(), "skipping acquire; job is already running")
return
@@ -184,15 +181,10 @@ func (p *provisionerDaemon) acquireJob(ctx context.Context) {
p.opts.Logger.Warn(context.Background(), "acquire job", slog.Error(err))
return
}
if p.isClosed() {
return
}
if job.JobId == "" {
p.opts.Logger.Debug(context.Background(), "no jobs available")
return
}
p.closeMutex.Lock()
defer p.closeMutex.Unlock()
ctx, p.jobCancel = context.WithCancel(ctx)
p.jobRunning = make(chan struct{})
p.jobID = job.JobId
@@ -222,31 +214,27 @@ func (p *provisionerDaemon) runJob(ctx context.Context, job *proto.AcquiredJob)
JobId: job.JobId,
})
if err != nil {
go p.cancelActiveJob(fmt.Sprintf("send periodic update: %s", err))
go p.cancelActiveJobf("send periodic update: %s", err)
return
}
}
}()
defer func() {
// Cleanup the work directory after execution.
err := os.RemoveAll(p.opts.WorkDirectory)
if err != nil {
go p.cancelActiveJob(fmt.Sprintf("remove all from %q directory: %s", p.opts.WorkDirectory, err))
return
}
p.opts.Logger.Debug(ctx, "cleaned up work directory")
p.opts.Logger.Debug(ctx, "cleaned up work directory", slog.Error(err))
close(p.jobRunning)
}()
// It's safe to cast this ProvisionerType. This data is coming directly from coderd.
provisioner, hasProvisioner := p.opts.Provisioners[job.Provisioner]
if !hasProvisioner {
go p.cancelActiveJob(fmt.Sprintf("provisioner %q not registered", job.Provisioner))
go p.cancelActiveJobf("provisioner %q not registered", job.Provisioner)
return
}

err := os.MkdirAll(p.opts.WorkDirectory, 0700)
if err != nil {
go p.cancelActiveJob(fmt.Sprintf("create work directory %q: %s", p.opts.WorkDirectory, err))
go p.cancelActiveJobf("create work directory %q: %s", p.opts.WorkDirectory, err)
return
}

@@ -258,13 +246,13 @@ func (p *provisionerDaemon) runJob(ctx context.Context, job *proto.AcquiredJob)
break
}
if err != nil {
go p.cancelActiveJob(fmt.Sprintf("read project source archive: %s", err))
go p.cancelActiveJobf("read project source archive: %s", err)
return
}
// #nosec
path := filepath.Join(p.opts.WorkDirectory, header.Name)
if !strings.HasPrefix(path, filepath.Clean(p.opts.WorkDirectory)) {
go p.cancelActiveJob("tar attempts to target relative upper directory")
go p.cancelActiveJobf("tar attempts to target relative upper directory")
return
}
mode := header.FileInfo().Mode()
@@ -275,14 +263,14 @@ func (p *provisionerDaemon) runJob(ctx context.Context, job *proto.AcquiredJob)
case tar.TypeDir:
err = os.MkdirAll(path, mode)
if err != nil {
go p.cancelActiveJob(fmt.Sprintf("mkdir %q: %s", path, err))
go p.cancelActiveJobf("mkdir %q: %s", path, err)
return
}
p.opts.Logger.Debug(context.Background(), "extracted directory", slog.F("path", path))
case tar.TypeReg:
file, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, mode)
if err != nil {
go p.cancelActiveJob(fmt.Sprintf("create file %q (mode %s): %s", path, mode, err))
go p.cancelActiveJobf("create file %q (mode %s): %s", path, mode, err)
return
}
// Max file size of 10MB.
@@ -291,12 +279,12 @@ func (p *provisionerDaemon) runJob(ctx context.Context, job *proto.AcquiredJob)
err = nil
}
if err != nil {
go p.cancelActiveJob(fmt.Sprintf("copy file %q: %s", path, err))
go p.cancelActiveJobf("copy file %q: %s", path, err)
return
}
err = file.Close()
if err != nil {
go p.cancelActiveJob(fmt.Sprintf("close file %q: %s", path, err))
go p.cancelActiveJobf("close file %q: %s", path, err)
return
}
p.opts.Logger.Debug(context.Background(), "extracted file",
@@ -323,7 +311,7 @@ func (p *provisionerDaemon) runJob(ctx context.Context, job *proto.AcquiredJob)

p.runWorkspaceProvision(ctx, provisioner, job)
default:
go p.cancelActiveJob(fmt.Sprintf("unknown job type %q; ensure your provisioner daemon is up-to-date", reflect.TypeOf(job.Type).String()))
go p.cancelActiveJobf("unknown job type %q; ensure your provisioner daemon is up-to-date", reflect.TypeOf(job.Type).String())
return
}

@@ -335,14 +323,14 @@ func (p *provisionerDaemon) runProjectImport(ctx context.Context, provisioner sd
Directory: p.opts.WorkDirectory,
})
if err != nil {
go p.cancelActiveJob(fmt.Sprintf("parse source: %s", err))
go p.cancelActiveJobf("parse source: %s", err)
return
}
defer stream.Close()
for {
msg, err := stream.Recv()
if err != nil {
go p.cancelActiveJob(fmt.Sprintf("recv parse source: %s", err))
go p.cancelActiveJobf("recv parse source: %s", err)
return
}
switch msgType := msg.Type.(type) {
@@ -363,7 +351,7 @@ func (p *provisionerDaemon) runProjectImport(ctx context.Context, provisioner sd
}},
})
if err != nil {
go p.cancelActiveJob(fmt.Sprintf("update job: %s", err))
go p.cancelActiveJobf("update job: %s", err)
return
}
case *sdkproto.Parse_Response_Complete:
@@ -379,14 +367,14 @@ func (p *provisionerDaemon) runProjectImport(ctx context.Context, provisioner sd
},
})
if err != nil {
go p.cancelActiveJob(fmt.Sprintf("complete job: %s", err))
go p.cancelActiveJobf("complete job: %s", err)
return
}
// Return so we stop looping!
return
default:
go p.cancelActiveJob(fmt.Sprintf("invalid message type %q received from provisioner",
reflect.TypeOf(msg.Type).String()))
go p.cancelActiveJobf("invalid message type %q received from provisioner",
reflect.TypeOf(msg.Type).String())
return
}
}
@@ -399,15 +387,15 @@ func (p *provisionerDaemon) runWorkspaceProvision(ctx context.Context, provision
State: job.GetWorkspaceProvision().State,
})
if err != nil {
go p.cancelActiveJob(fmt.Sprintf("provision: %s", err))
go p.cancelActiveJobf("provision: %s", err)
return
}
defer stream.Close()

for {
msg, err := stream.Recv()
if err != nil {
go p.cancelActiveJob(fmt.Sprintf("recv workspace provision: %s", err))
go p.cancelActiveJobf("recv workspace provision: %s", err)
return
}
switch msgType := msg.Type.(type) {
@@ -428,7 +416,7 @@ func (p *provisionerDaemon) runWorkspaceProvision(ctx context.Context, provision
}},
})
if err != nil {
go p.cancelActiveJob(fmt.Sprintf("send job update: %s", err))
go p.cancelActiveJobf("send job update: %s", err)
return
}
case *sdkproto.Provision_Response_Complete:
@@ -450,26 +438,28 @@ func (p *provisionerDaemon) runWorkspaceProvision(ctx context.Context, provision
},
})
if err != nil {
go p.cancelActiveJob(fmt.Sprintf("complete job: %s", err))
go p.cancelActiveJobf("complete job: %s", err)
return
}
// Return so we stop looping!
return
default:
go p.cancelActiveJob(fmt.Sprintf("invalid message type %q received from provisioner",
reflect.TypeOf(msg.Type).String()))
go p.cancelActiveJobf("invalid message type %q received from provisioner",
reflect.TypeOf(msg.Type).String())
return
}
}
}

func (p *provisionerDaemon) cancelActiveJob(errMsg string) {
func (p *provisionerDaemon) cancelActiveJobf(format string, args ...interface{}) {
p.jobMutex.Lock()
defer p.jobMutex.Unlock()
if p.isClosed() {
return
}
errMsg := fmt.Sprintf(format, args...)
if !p.isRunningJob() {
if p.isClosed() {
// We don't want to log if we're already closed!
return
}
p.opts.Logger.Warn(context.Background(), "skipping job cancel; none running", slog.F("error_message", errMsg))
return
}
@@ -512,22 +502,20 @@ func (p *provisionerDaemon) closeWithError(err error) error {
if p.isClosed() {
return p.closeError
}
p.closeCancel()
p.closeError = err
close(p.closed)

errMsg := "provisioner daemon was shutdown gracefully"
if err != nil {
errMsg = err.Error()
}
p.cancelActiveJob(errMsg)
p.jobMutex.Lock()
defer p.jobMutex.Unlock()
p.opts.Logger.Debug(context.Background(), "closing server with error", slog.Error(err))
p.closeError = err
close(p.closed)
p.cancelActiveJobf(errMsg)
p.closeCancel()

if p.updateStream != nil {
_ = p.client.DRPCConn().Close()
_ = p.updateStream.Close()
}
// Required until we're on Go 1.18. See:
// https://github.com/golang/go/issues/50510
_ = os.RemoveAll(p.opts.WorkDirectory)
p.opts.Logger.Debug(context.Background(), "closing server with error", slog.Error(err))

return err
}
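
As an aside on the connect() comment above about exponential back-off: the loop below is a generic sketch of that retry pattern, not the actual provisionerd dial code; dialWithBackoff, its delays, and the fake dial function in main are made up for illustration. Like the new code in the diff, it returns quietly on context.Canceled so shutdown does not produce retry noise.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// dialWithBackoff retries dial with exponentially growing delays so a
// server outage is not hammered with connection attempts.
func dialWithBackoff(ctx context.Context, dial func(context.Context) error) error {
	delay := 50 * time.Millisecond
	const maxDelay = 5 * time.Second
	for {
		err := dial(ctx)
		if err == nil {
			return nil
		}
		if errors.Is(err, context.Canceled) {
			return err // shutting down; stop retrying quietly
		}
		fmt.Printf("dial failed: %s; retrying in %s\n", err, delay)
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(delay):
		}
		delay *= 2
		if delay > maxDelay {
			delay = maxDelay
		}
	}
}

func main() {
	attempts := 0
	err := dialWithBackoff(context.Background(), func(context.Context) error {
		attempts++
		if attempts < 3 {
			return fmt.Errorf("connection refused")
		}
		return nil
	})
	fmt.Println("connected after", attempts, "attempts, err:", err)
}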
