Skip to content

Commit

Permalink
windows: wait for pending start before stop
Browse files Browse the repository at this point in the history
Also buffer the OS signal so it's not potentially lost during Run.
  • Loading branch information
djdv committed Apr 24, 2021
1 parent ef35c56 commit 575d2ae
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 25 deletions.
16 changes: 11 additions & 5 deletions service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package service_test

import (
"fmt"
"os"
"testing"
"time"
Expand All @@ -22,22 +23,27 @@ func TestRunInterrupt(t *testing.T) {
t.Fatalf("New err: %s", err)
}

retChan := make(chan error)
go func() {
if err = s.Run(); err != nil {
retChan <- fmt.Errorf("Run() err: %w", err)
}
}()
go func() {
<-time.After(1 * time.Second)
interruptProcess(t)
}()

go func() {
for i := 0; i < 25 && p.numStopped == 0; i++ {
<-time.After(200 * time.Millisecond)
}
if p.numStopped == 0 {
t.Fatal("Run() hasn't been stopped")
retChan <- fmt.Errorf("Run() hasn't been stopped")
}
retChan <- nil
}()

if err = s.Run(); err != nil {
t.Fatalf("Run() err: %s", err)
if err = <-retChan; err != nil {
t.Fatal(err)
}
}

Expand Down
100 changes: 80 additions & 20 deletions service_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ func (ws *windowsService) Run() error {
return err
}

sigChan := make(chan os.Signal)
sigChan := make(chan os.Signal, 1)

signal.Notify(sigChan, os.Interrupt)

Expand Down Expand Up @@ -381,6 +381,8 @@ func (ws *windowsService) Restart() error {
}
defer s.Close()

// First stop the service. Then wait for the service to
// actually stop before starting it.
err = ws.stopWait(s)
if err != nil {
return err
Expand All @@ -389,32 +391,90 @@ func (ws *windowsService) Restart() error {
return s.Start()
}

// statusInterval retreives a (bounded) duration from the status,
// or provides a default.
func statusInterval(status svc.Status) time.Duration {
// MSDN:
// "Do not wait longer than the wait hint. A good interval is
// one-tenth of the wait hint but not less than 1 second
// and not more than 10 seconds."
const (
lower = time.Second
upper = time.Second * 10
)

waitDuration := (time.Duration(status.WaitHint) * time.Millisecond) / 10
if waitDuration < lower {
waitDuration = lower
} else if waitDuration > upper {
waitDuration = upper
}
return waitDuration
}

// waitForStateChange polls the service until its state matches the desiredState,
// encounters an error, or times out.
// The initialTimeout may be extended if the service responds with checkpoints.
func waitForStateChange(s *mgr.Service, initialTimeout time.Duration, currentStatus svc.Status, desiredState svc.State) error {
var (
queryErr error
lastCheck = currentStatus.CheckPoint
queryInterval = statusInterval(currentStatus)
queryTicker = time.NewTicker(queryInterval)
queryTimeout = time.NewTimer(initialTimeout)
)
defer func() {
queryTicker.Stop()
queryTimeout.Stop()
}()

for currentStatus.State != desiredState {
select {
case <-queryTicker.C:
currentStatus, queryErr = s.Query()
if queryErr != nil {
return queryErr
}
// If the service is providing hints, use them.
if currentStatus.CheckPoint > lastCheck {
lastCheck = currentStatus.CheckPoint
if !queryTimeout.Stop() {
<-queryTimeout.C
}
// Start progressed,
// give the service more time to complete.
queryTimeout.Reset(getStopTimeout() +
statusInterval(currentStatus))
}
case <-queryTimeout.C:
return fmt.Errorf("status poll timed out before service state reached %v", desiredState)
}
}
return nil
}

func (ws *windowsService) stopWait(s *mgr.Service) error {
// First stop the service. Then wait for the service to
// actually stop before starting it.
status, err := s.Control(svc.Stop)
status, err := s.Query()
if err != nil {
return err
}

timeDuration := time.Millisecond * 50

timeout := time.After(getStopTimeout() + (timeDuration * 2))
tick := time.NewTicker(timeDuration)
defer tick.Stop()

for status.State != svc.Stopped {
select {
case <-tick.C:
status, err = s.Query()
if err != nil {
return err
}
case <-timeout:
break
if status.State == svc.StartPending {
// Service cannot be stopped before it is started.
// Wait for it to complete first.
initialTimeout := statusInterval(status)
if err = waitForStateChange(s, initialTimeout, status, svc.Running); err != nil {
return err
}
}
return nil

status, err = s.Control(svc.Stop)
if err != nil {
return err
}

initialTimeout := getStopTimeout() + statusInterval(status)
return waitForStateChange(s, initialTimeout, status, svc.Stopped)
}

// getStopTimeout fetches the time before windows will kill the service.
Expand Down

0 comments on commit 575d2ae

Please sign in to comment.