Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,11 @@ func (t *Trap) matches(c *apiCall) bool {
func (t *Trap) Close() {
t.mock.mu.Lock()
defer t.mock.mu.Unlock()
select {
case <-t.done:
return // already closed
default:
}
if t.unreleasedCalls != 0 {
t.mock.tb.Helper()
t.mock.tb.Errorf("trap Closed() with %d unreleased calls", t.unreleasedCalls)
Expand Down
61 changes: 60 additions & 1 deletion mock_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package quartz_test

import (
"bytes"
"context"
"errors"
"fmt"
"os"
"runtime/pprof"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -403,7 +407,10 @@ func Test_UnreleasedCalls(t *testing.T) {
_ = mClock.Now()
}()

trap.MustWait(testCtx) // missing release
c := trap.MustWait(testCtx) // missing release
trap.Close() // detect unreleased call and fail

c.Release(testCtx) // clean up goroutine
})
}

Expand Down Expand Up @@ -573,3 +580,55 @@ func TestTickerStop_Go123(t *testing.T) {
// OK!
}
}

func TestMain(m *testing.M) {
verifyNoLeakTestMain(m)
}

func verifyNoLeakTestMain(m *testing.M) {
before := snapshot()
code := m.Run()
now := time.Now()
for {
after := snapshot()
if len(after) > len(before) {
// Allow test cleanup to settle.
if time.Since(now) < 200*time.Millisecond {
time.Sleep(50 * time.Millisecond)
continue
}
fmt.Fprintln(os.Stderr, "Possible goroutine leak(s):")
fmt.Fprintln(os.Stderr, diff(before, after))
os.Exit(1)
}
os.Exit(code)
}
}

func snapshot() []string {
var buf bytes.Buffer
_ = pprof.Lookup("goroutine").WriteTo(&buf, 2)
var clean []string
for _, s := range strings.Split(buf.String(), "\n\n") {
if !strings.Contains(s, "runtime/pprof") {
clean = append(clean, s)
}
}
return clean
}

func diff(a, b []string) string {
m := make(map[string]int)
for _, s := range a {
m[s]++
}
var leaks []string
for _, s := range b {
if m[s] > 0 {
m[s]--
continue
}
leaks = append(leaks, s)
}
return strings.Join(leaks, "\n\n")
}