From 366119e41ca55b68fa1dd9fe7c258fa80137a616 Mon Sep 17 00:00:00 2001 From: myzhan Date: Mon, 11 Mar 2019 16:06:15 +0800 Subject: [PATCH] Expose the Boomer type and the RateLimiter interface Now users are able to control more options programatically instead of passing command line parameters. And they can write their own rate limiters. The cache of current time in second is removed, it may be a premature optimization. Resolves: #54, #56 --- boomer.go | 215 +++++++++++++++--------------- boomer_test.go | 65 ++++----- events.go | 28 ---- events_test.go | 34 ----- examples/main.go | 34 ++++- examples/ratelimit/maxrps/main.go | 58 ++++++++ examples/ratelimit/rampup/main.go | 58 ++++++++ legacy.go | 61 +++++++-- legacy_test.go | 14 +- ratelimiter.go | 78 +++++++---- ratelimiter_test.go | 40 +++--- runner.go | 51 +++---- runner_test.go | 9 +- stats.go | 4 +- stats_test.go | 26 ---- utils.go | 43 ++++++ utils_test.go | 27 ++++ 17 files changed, 500 insertions(+), 345 deletions(-) delete mode 100644 events.go delete mode 100644 events_test.go create mode 100644 examples/ratelimit/maxrps/main.go create mode 100644 examples/ratelimit/rampup/main.go diff --git a/boomer.go b/boomer.go index fec9746..896ff05 100644 --- a/boomer.go +++ b/boomer.go @@ -3,45 +3,102 @@ package boomer import ( "flag" "log" - "math" "os" "os/signal" - "runtime/pprof" "strings" - "sync" - "sync/atomic" "syscall" "time" + + "github.com/asaskevich/EventBus" ) -var masterHost string -var masterPort int +// Events is the global event bus instance. +var Events = EventBus.New() -var maxRPS int64 -var requestIncreaseRate string -var hatchType string -var runTasks string -var memoryProfile string -var memoryProfileDuration time.Duration -var cpuProfile string -var cpuProfileDuration time.Duration +var defaultBoomer *Boomer -var defaultRunner *runner +// A Boomer is used to run tasks. +type Boomer struct { + masterHost string + masterPort int -var initiated uint32 -var initMutex = sync.Mutex{} + hatchType string + rateLimiter RateLimiter + runner *runner +} -// Init boomer -func initBoomer() { - if atomic.LoadUint32(&initiated) == 1 { - panic("Don't call boomer.Run() more than once.") +// NewBoomer returns a new Boomer. +func NewBoomer(masterHost string, masterPort int) *Boomer { + return &Boomer{ + masterHost: masterHost, + masterPort: masterPort, + hatchType: "asap", } +} - // TODO: to be removed - initLegacyEventHandlers() +// SetRateLimiter allows user to use their own rate limiter. +// It must be called before the test is started. +func (b *Boomer) SetRateLimiter(rateLimiter RateLimiter) { + b.rateLimiter = rateLimiter +} + +// SetHatchType only accepts "asap" or "smooth". +// "asap" means spawning goroutines as soon as possible when the test is started. +// "smooth" means a constant pace. +func (b *Boomer) SetHatchType(hatchType string) { + if hatchType != "asap" && hatchType != "smooth" { + log.Printf("Wrong hatch-type, expected asap or smooth, was %s\n", hatchType) + return + } + b.hatchType = hatchType +} + +func (b *Boomer) setRunner(runner *runner) { + b.runner = runner +} + +// Run accepts a slice of Task and connects to the locust master. +func (b *Boomer) Run(tasks ...*Task) { + b.runner = newRunner(tasks, b.rateLimiter, b.hatchType) + b.runner.masterHost = b.masterHost + b.runner.masterPort = b.masterPort + b.runner.getReady() +} + +// RecordSuccess reports a success +func (b *Boomer) RecordSuccess(requestType, name string, responseTime int64, responseLength int64) { + b.runner.stats.requestSuccessChannel <- &requestSuccess{ + requestType: requestType, + name: name, + responseTime: responseTime, + responseLength: responseLength, + } +} - // done - atomic.StoreUint32(&initiated, 1) +// RecordFailure reports a failure. +func (b *Boomer) RecordFailure(requestType, name string, responseTime int64, exception string) { + b.runner.stats.requestFailureChannel <- &requestFailure{ + requestType: requestType, + name: name, + responseTime: responseTime, + error: exception, + } +} + +// Quit will send a quit message to the master. +func (b *Boomer) Quit() { + Events.Publish("boomer:quit") + var ticker = time.NewTicker(3 * time.Second) + for { + // wait for quit message is sent to master + select { + case <-b.runner.client.disconnectedChannel(): + return + case <-ticker.C: + log.Println("Timeout waiting for sending quit message to master, boomer will quit any way.") + return + } + } } // Run tasks without connecting to the master. @@ -61,26 +118,8 @@ func runTasksForTest(tasks ...*Task) { } } -func createRateLimiter(maxRPS int64, requestIncreaseRate string) (rateLimiter rateLimiter, err error) { - if requestIncreaseRate != "-1" { - if maxRPS > 0 { - log.Println("The max RPS that boomer may generate is limited to", maxRPS, "with a increase rate", requestIncreaseRate) - rateLimiter, err = newRampUpRateLimiter(maxRPS, requestIncreaseRate, time.Second) - } else { - log.Println("The max RPS that boomer may generate is limited by a increase rate", requestIncreaseRate) - rateLimiter, err = newRampUpRateLimiter(math.MaxInt64, requestIncreaseRate, time.Second) - } - } else { - if maxRPS > 0 { - log.Println("The max RPS that boomer may generate is limited to", maxRPS) - rateLimiter = newStableRateLimiter(maxRPS, time.Second) - } - } - return rateLimiter, err -} - -// Run accepts a slice of Task and connects -// to a locust master. +// Run accepts a slice of Task and connects to a locust master. +// It's a convenience function to use the defaultBoomer. func Run(tasks ...*Task) { if !flag.Parsed() { flag.Parse() @@ -91,85 +130,43 @@ func Run(tasks ...*Task) { return } - // init boomer - initMutex.Lock() - initBoomer() - initMutex.Unlock() - - rateLimiter, err := createRateLimiter(maxRPS, requestIncreaseRate) - if err != nil { - log.Fatalf("Failed to create rate limiter, %v\n", err) - } - - defaultRunner = newRunner(tasks, rateLimiter, hatchType) - defaultRunner.masterHost = masterHost - defaultRunner.masterPort = masterPort - defaultRunner.getReady() + defaultBoomer = NewBoomer(masterHost, masterPort) + initLegacyEventHandlers() if memoryProfile != "" { - startMemoryProfile(memoryProfile, memoryProfileDuration) + StartMemoryProfile(memoryProfile, memoryProfileDuration) } if cpuProfile != "" { - startCPUProfile(cpuProfile, cpuProfileDuration) + StartCPUProfile(cpuProfile, cpuProfileDuration) } + rateLimiter, err := createRateLimiter(maxRPS, requestIncreaseRate) + if err != nil { + log.Fatalf("%v\n", err) + } + defaultBoomer.SetRateLimiter(rateLimiter) + defaultBoomer.hatchType = hatchType + + defaultBoomer.Run(tasks...) + c := make(chan os.Signal) signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) <-c - Events.Publish("boomer:quit") + defaultBoomer.Quit() - // wait for quit message is sent to master - <-defaultRunner.client.disconnectedChannel() log.Println("shut down") } -func startMemoryProfile(file string, duration time.Duration) { - f, err := os.Create(file) - if err != nil { - log.Fatal(err) - } - time.AfterFunc(duration, func() { - err = pprof.WriteHeapProfile(f) - if err != nil { - log.Println(err) - return - } - f.Close() - log.Println("Stop memory profiling after", duration) - }) -} - -func startCPUProfile(file string, duration time.Duration) { - f, err := os.Create(file) - if err != nil { - log.Fatal(err) - } - - err = pprof.StartCPUProfile(f) - if err != nil { - log.Println(err) - f.Close() - return - } - - time.AfterFunc(duration, func() { - pprof.StopCPUProfile() - f.Close() - log.Println("Stop CPU profiling after", duration) - }) +// RecordSuccess reports a success. +// It's a convenience function to use the defaultBoomer. +func RecordSuccess(requestType, name string, responseTime int64, responseLength int64) { + defaultBoomer.RecordSuccess(requestType, name, responseTime, responseLength) } -func init() { - flag.Int64Var(&maxRPS, "max-rps", 0, "Max RPS that boomer can generate, disabled by default.") - flag.StringVar(&requestIncreaseRate, "request-increase-rate", "-1", "Request increase rate, disabled by default.") - flag.StringVar(&hatchType, "hatch-type", "asap", "How to create goroutines according to hatch rate, 'asap' will do it as soon as possible while 'smooth' means a constant pace.") - flag.StringVar(&runTasks, "run-tasks", "", "Run tasks without connecting to the master, multiply tasks is separated by comma. Usually, it's for debug purpose.") - flag.StringVar(&masterHost, "master-host", "127.0.0.1", "Host or IP address of locust master for distributed load testing.") - flag.IntVar(&masterPort, "master-port", 5557, "The port to connect to that is used by the locust master for distributed load testing.") - flag.StringVar(&memoryProfile, "mem-profile", "", "Enable memory profiling.") - flag.DurationVar(&memoryProfileDuration, "mem-profile-duration", 30*time.Second, "Memory profile duration.") - flag.StringVar(&cpuProfile, "cpu-profile", "", "Enable CPU profiling.") - flag.DurationVar(&cpuProfileDuration, "cpu-profile-duration", 30*time.Second, "CPU profile duration.") +// RecordFailure reports a failure. +// It's a convenience function to use the defaultBoomer. +func RecordFailure(requestType, name string, responseTime int64, exception string) { + defaultBoomer.RecordFailure(requestType, name, responseTime, exception) } diff --git a/boomer_test.go b/boomer_test.go index c4e084c..e773422 100644 --- a/boomer_test.go +++ b/boomer_test.go @@ -2,25 +2,10 @@ package boomer import ( "math" - "os" "testing" "time" ) -func TestInitBoomer(t *testing.T) { - initBoomer() - defer Events.Unsubscribe("request_success", legacySuccessHandler) - defer Events.Unsubscribe("request_failure", legacyFailureHandler) - - defer func() { - err := recover() - if err == nil { - t.Error("It should panic if initBoomer is called more than once.") - } - }() - initBoomer() -} - func TestRunTasksForTest(t *testing.T) { count := 0 taskA := &Task{ @@ -36,6 +21,7 @@ func TestRunTasksForTest(t *testing.T) { }, } runTasks = "increaseCount,foobar" + runTasksForTest(taskA, taskWithoutName) if count != 1 { @@ -45,7 +31,7 @@ func TestRunTasksForTest(t *testing.T) { func TestCreateRatelimiter(t *testing.T) { rateLimiter, _ := createRateLimiter(100, "-1") - if stableRateLimiter, ok := rateLimiter.(*stableRateLimiter); !ok { + if stableRateLimiter, ok := rateLimiter.(*StableRateLimiter); !ok { t.Error("Expected stableRateLimiter") } else { if stableRateLimiter.threshold != 100 { @@ -54,7 +40,7 @@ func TestCreateRatelimiter(t *testing.T) { } rateLimiter, _ = createRateLimiter(0, "1") - if rampUpRateLimiter, ok := rateLimiter.(*rampUpRateLimiter); !ok { + if rampUpRateLimiter, ok := rateLimiter.(*RampUpRateLimiter); !ok { t.Error("Expected rampUpRateLimiter") } else { if rampUpRateLimiter.maxThreshold != math.MaxInt64 { @@ -66,7 +52,7 @@ func TestCreateRatelimiter(t *testing.T) { } rateLimiter, _ = createRateLimiter(10, "2/2s") - if rampUpRateLimiter, ok := rateLimiter.(*rampUpRateLimiter); !ok { + if rampUpRateLimiter, ok := rateLimiter.(*RampUpRateLimiter); !ok { t.Error("Expected rampUpRateLimiter") } else { if rampUpRateLimiter.maxThreshold != 10 { @@ -84,28 +70,35 @@ func TestCreateRatelimiter(t *testing.T) { } } -func TestStartMemoryProfile(t *testing.T) { - if _, err := os.Stat("mem.pprof"); os.IsExist(err) { - os.Remove("mem.pprof") +func TestRecordSuccess(t *testing.T) { + defaultBoomer = NewBoomer("127.0.0.1", 5557) + defaultBoomer.runner = newRunner(nil, nil, "asap") + RecordSuccess("http", "foo", int64(1), int64(10)) + + requestSuccessMsg := <-defaultBoomer.runner.stats.requestSuccessChannel + if requestSuccessMsg.requestType != "http" { + t.Error("Expected: http, got:", requestSuccessMsg.requestType) } - startMemoryProfile("mem.pprof", 2*time.Second) - time.Sleep(2100 * time.Millisecond) - if _, err := os.Stat("mem.pprof"); os.IsNotExist(err) { - t.Error("File mem.pprof is not generated") - } else { - os.Remove("mem.pprof") + if requestSuccessMsg.responseTime != int64(1) { + t.Error("Expected: 1, got:", requestSuccessMsg.responseTime) } + defaultBoomer = nil } -func TestStartCPUProfile(t *testing.T) { - if _, err := os.Stat("cpu.pprof"); os.IsExist(err) { - os.Remove("cpu.pprof") +func TestRecordFailure(t *testing.T) { + defaultBoomer = NewBoomer("127.0.0.1", 5557) + defaultBoomer.runner = newRunner(nil, nil, "asap") + RecordFailure("udp", "bar", int64(2), "udp error") + + requestFailureMsg := <-defaultBoomer.runner.stats.requestFailureChannel + if requestFailureMsg.requestType != "udp" { + t.Error("Expected: udp, got:", requestFailureMsg.requestType) } - startCPUProfile("cpu.pprof", 2*time.Second) - time.Sleep(2100 * time.Millisecond) - if _, err := os.Stat("cpu.pprof"); os.IsNotExist(err) { - t.Error("File cpu.pprof is not generated") - } else { - os.Remove("cpu.pprof") + if requestFailureMsg.responseTime != int64(2) { + t.Error("Expected: 2, got:", requestFailureMsg.responseTime) + } + if requestFailureMsg.error != "udp error" { + t.Error("Expected: udp error, got:", requestFailureMsg.error) } + defaultBoomer = nil } diff --git a/events.go b/events.go deleted file mode 100644 index 12dda0e..0000000 --- a/events.go +++ /dev/null @@ -1,28 +0,0 @@ -package boomer - -import ( - "github.com/asaskevich/EventBus" -) - -// Events is core event bus instance of boomer -var Events = EventBus.New() - -// RecordSuccess reports a success -func RecordSuccess(requestType, name string, responseTime int64, responseLength int64) { - defaultRunner.stats.requestSuccessChannel <- &requestSuccess{ - requestType: requestType, - name: name, - responseTime: responseTime, - responseLength: responseLength, - } -} - -// RecordFailure reports a failure -func RecordFailure(requestType, name string, responseTime int64, exception string) { - defaultRunner.stats.requestFailureChannel <- &requestFailure{ - requestType: requestType, - name: name, - responseTime: responseTime, - error: exception, - } -} diff --git a/events_test.go b/events_test.go deleted file mode 100644 index cf90918..0000000 --- a/events_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package boomer - -import "testing" - -func TestRecordSuccess(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") - RecordSuccess("http", "foo", int64(1), int64(10)) - - requestSuccessMsg := <-defaultRunner.stats.requestSuccessChannel - if requestSuccessMsg.requestType != "http" { - t.Error("Expected: http, got:", requestSuccessMsg.requestType) - } - if requestSuccessMsg.responseTime != int64(1) { - t.Error("Expected: 1, got:", requestSuccessMsg.responseTime) - } - defaultRunner = nil -} - -func TestRecordFailure(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") - RecordFailure("udp", "bar", int64(2), "udp error") - - requestFailureMsg := <-defaultRunner.stats.requestFailureChannel - if requestFailureMsg.requestType != "udp" { - t.Error("Expected: udp, got:", requestFailureMsg.requestType) - } - if requestFailureMsg.responseTime != int64(2) { - t.Error("Expected: 2, got:", requestFailureMsg.responseTime) - } - if requestFailureMsg.error != "udp error" { - t.Error("Expected: udp error, got:", requestFailureMsg.error) - } - defaultRunner = nil -} diff --git a/examples/main.go b/examples/main.go index cea265e..9518c1e 100644 --- a/examples/main.go +++ b/examples/main.go @@ -2,33 +2,52 @@ package main import ( "log" + "os" + "os/signal" + "sync" + "syscall" "time" "github.com/myzhan/boomer" ) func foo() { - start := boomer.Now() time.Sleep(100 * time.Millisecond) elapsed := boomer.Now() - start // Report your test result as a success, if you write it in python, it will looks like this // events.request_success.fire(request_type="http", name="foo", response_time=100, response_length=10) - boomer.RecordSuccess("http", "foo", elapsed, int64(10)) + globalBoomer.RecordSuccess("http", "foo", elapsed, int64(10)) } func bar() { - start := boomer.Now() time.Sleep(100 * time.Millisecond) elapsed := boomer.Now() - start // Report your test result as a failure, if you write it in python, it will looks like this // events.request_failure.fire(request_type="udp", name="bar", response_time=100, exception=Exception("udp error")) - boomer.RecordFailure("udp", "bar", elapsed, "udp error") + globalBoomer.RecordFailure("udp", "bar", elapsed, "udp error") +} + +func waitForQuit() { + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + c := make(chan os.Signal) + signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + <-c + globalBoomer.Quit() + wg.Done() + }() + + wg.Wait() } +var globalBoomer = boomer.NewBoomer("127.0.0.1", 5557) + func main() { log.SetFlags(log.LstdFlags | log.Lshortfile) @@ -40,9 +59,12 @@ func main() { task2 := &boomer.Task{ Name: "bar", - Weight: 20, + Weight: 30, Fn: bar, } - boomer.Run(task1, task2) + globalBoomer.Run(task1, task2) + + waitForQuit() + log.Println("shut down") } diff --git a/examples/ratelimit/maxrps/main.go b/examples/ratelimit/maxrps/main.go new file mode 100644 index 0000000..fea96d5 --- /dev/null +++ b/examples/ratelimit/maxrps/main.go @@ -0,0 +1,58 @@ +package main + +import ( + "log" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/myzhan/boomer" +) + +func foo() { + start := boomer.Now() + time.Sleep(100 * time.Millisecond) + elapsed := boomer.Now() - start + + // Report your test result as a success, if you write it in python, it will looks like this + // events.request_success.fire(request_type="http", name="foo", response_time=100, response_length=10) + globalBoomer.RecordSuccess("http", "foo", elapsed, int64(10)) +} + +func waitForQuit() { + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + c := make(chan os.Signal) + signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + <-c + globalBoomer.Quit() + wg.Done() + }() + + wg.Wait() +} + +var globalBoomer = boomer.NewBoomer("127.0.0.1", 5557) + +func main() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + task1 := &boomer.Task{ + Name: "foo", + Weight: 10, + Fn: foo, + } + + ratelimiter := boomer.NewStableRateLimiter(100, time.Second) + log.Println("the max rps is limited to 100/s.") + globalBoomer.SetRateLimiter(ratelimiter) + + globalBoomer.Run(task1) + + waitForQuit() + log.Println("shut down") +} diff --git a/examples/ratelimit/rampup/main.go b/examples/ratelimit/rampup/main.go new file mode 100644 index 0000000..dec10d8 --- /dev/null +++ b/examples/ratelimit/rampup/main.go @@ -0,0 +1,58 @@ +package main + +import ( + "log" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/myzhan/boomer" +) + +func foo() { + start := boomer.Now() + time.Sleep(100 * time.Millisecond) + elapsed := boomer.Now() - start + + // Report your test result as a success, if you write it in python, it will looks like this + // events.request_success.fire(request_type="http", name="foo", response_time=100, response_length=10) + globalBoomer.RecordSuccess("http", "foo", elapsed, int64(10)) +} + +func waitForQuit() { + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + c := make(chan os.Signal) + signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + <-c + globalBoomer.Quit() + wg.Done() + }() + + wg.Wait() +} + +var globalBoomer = boomer.NewBoomer("127.0.0.1", 5557) + +func main() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + task1 := &boomer.Task{ + Name: "foo", + Weight: 10, + Fn: foo, + } + + ratelimiter, _ := boomer.NewRampUpRateLimiter(1000, "100/1s", time.Second) + log.Println("the max rps is limited to 1000/s, with a rampup rate 100/1s.") + globalBoomer.SetRateLimiter(ratelimiter) + + globalBoomer.Run(task1) + + waitForQuit() + log.Println("shut down") +} diff --git a/legacy.go b/legacy.go index 16f1509..8bb8d5e 100644 --- a/legacy.go +++ b/legacy.go @@ -1,15 +1,49 @@ +// This file is kept to ensure backward compatibility. + package boomer import ( + "flag" "fmt" "log" + "math" "reflect" "sync" + "time" ) +var masterHost string +var masterPort int +var maxRPS int64 +var requestIncreaseRate string +var hatchType string +var runTasks string +var memoryProfile string +var memoryProfileDuration time.Duration +var cpuProfile string +var cpuProfileDuration time.Duration + var successRetiredWarning = &sync.Once{} var failureRetiredWarning = &sync.Once{} +func createRateLimiter(maxRPS int64, requestIncreaseRate string) (rateLimiter RateLimiter, err error) { + if requestIncreaseRate != "-1" { + if maxRPS > 0 { + log.Println("The max RPS that boomer may generate is limited to", maxRPS, "with a increase rate", requestIncreaseRate) + rateLimiter, err = NewRampUpRateLimiter(maxRPS, requestIncreaseRate, time.Second) + } else { + log.Println("The max RPS that boomer may generate is limited by a increase rate", requestIncreaseRate) + rateLimiter, err = NewRampUpRateLimiter(math.MaxInt64, requestIncreaseRate, time.Second) + } + } else { + if maxRPS > 0 { + log.Println("The max RPS that boomer may generate is limited to", maxRPS) + rateLimiter = NewStableRateLimiter(maxRPS, time.Second) + } + } + return rateLimiter, err +} + // According to locust, responseTime should be int64, in milliseconds. // But previous version of boomer required responseTime to be float64, so sad. func convertResponseTime(origin interface{}) int64 { @@ -28,27 +62,30 @@ func legacySuccessHandler(requestType string, name string, responseTime interfac successRetiredWarning.Do(func() { log.Println("boomer.Events.Publish(\"request_success\") is less performant and deprecated, use boomer.RecordSuccess() instead.") }) - defaultRunner.stats.requestSuccessChannel <- &requestSuccess{ - requestType: requestType, - name: name, - responseTime: convertResponseTime(responseTime), - responseLength: responseLength, - } + defaultBoomer.RecordSuccess(requestType, name, convertResponseTime(responseTime), responseLength) } func legacyFailureHandler(requestType string, name string, responseTime interface{}, exception string) { failureRetiredWarning.Do(func() { log.Println("boomer.Events.Publish(\"request_failure\") is less performant and deprecated, use boomer.RecordFailure() instead.") }) - defaultRunner.stats.requestFailureChannel <- &requestFailure{ - requestType: requestType, - name: name, - responseTime: convertResponseTime(responseTime), - error: exception, - } + defaultBoomer.RecordFailure(requestType, name, convertResponseTime(responseTime), exception) } func initLegacyEventHandlers() { Events.Subscribe("request_success", legacySuccessHandler) Events.Subscribe("request_failure", legacyFailureHandler) } + +func init() { + flag.Int64Var(&maxRPS, "max-rps", 0, "Max RPS that boomer can generate, disabled by default.") + flag.StringVar(&requestIncreaseRate, "request-increase-rate", "-1", "Request increase rate, disabled by default.") + flag.StringVar(&hatchType, "hatch-type", "asap", "How to create goroutines according to hatch rate, 'asap' will do it as soon as possible while 'smooth' means a constant pace.") + flag.StringVar(&runTasks, "run-tasks", "", "Run tasks without connecting to the master, multiply tasks is separated by comma. Usually, it's for debug purpose.") + flag.StringVar(&masterHost, "master-host", "127.0.0.1", "Host or IP address of locust master for distributed load testing.") + flag.IntVar(&masterPort, "master-port", 5557, "The port to connect to that is used by the locust master for distributed load testing.") + flag.StringVar(&memoryProfile, "mem-profile", "", "Enable memory profiling.") + flag.DurationVar(&memoryProfileDuration, "mem-profile-duration", 30*time.Second, "Memory profile duration.") + flag.StringVar(&cpuProfile, "cpu-profile", "", "Enable CPU profiling.") + flag.DurationVar(&cpuProfileDuration, "cpu-profile-duration", 30*time.Second, "CPU profile duration.") +} diff --git a/legacy_test.go b/legacy_test.go index 3b0b046..a299c41 100644 --- a/legacy_test.go +++ b/legacy_test.go @@ -24,13 +24,16 @@ func TestConvertResponseTime(t *testing.T) { func TestInitEvents(t *testing.T) { initLegacyEventHandlers() + defer Events.Unsubscribe("request_success", legacySuccessHandler) + defer Events.Unsubscribe("request_failure", legacyFailureHandler) - defaultRunner = newRunner(nil, nil, "asap") + defaultBoomer = NewBoomer("127.0.0.1", 5557) + defaultBoomer.runner = newRunner(nil, nil, "asap") Events.Publish("request_success", "http", "foo", int64(1), int64(10)) Events.Publish("request_failure", "udp", "bar", int64(2), "udp error") - requestSuccessMsg := <-defaultRunner.stats.requestSuccessChannel + requestSuccessMsg := <-defaultBoomer.runner.stats.requestSuccessChannel if requestSuccessMsg.requestType != "http" { t.Error("Expected: http, got:", requestSuccessMsg.requestType) } @@ -38,7 +41,7 @@ func TestInitEvents(t *testing.T) { t.Error("Expected: 1, got:", requestSuccessMsg.responseTime) } - requestFailureMsg := <-defaultRunner.stats.requestFailureChannel + requestFailureMsg := <-defaultBoomer.runner.stats.requestFailureChannel if requestFailureMsg.requestType != "udp" { t.Error("Expected: udp, got:", requestFailureMsg.requestType) } @@ -48,9 +51,4 @@ func TestInitEvents(t *testing.T) { if requestFailureMsg.error != "udp error" { t.Error("Expected: udp error, got:", requestFailureMsg.error) } - - Events.Unsubscribe("request_success", legacySuccessHandler) - Events.Unsubscribe("request_failure", legacyFailureHandler) - - defaultRunner = nil } diff --git a/ratelimiter.go b/ratelimiter.go index 607c36d..5cf98b9 100644 --- a/ratelimiter.go +++ b/ratelimiter.go @@ -9,34 +9,53 @@ import ( "time" ) -// runner uses a rate limiter to put limits on task executions. -type rateLimiter interface { - start() - acquire() bool - stop() +// RateLimiter is used to put limits on task executions. +type RateLimiter interface { + // Start is used to enable the rate limiter. + // It can be implemented as a noop if not needed. + Start() + + // Acquire() is called before executing a task.Fn function. + // If Acquire() returns true, the task.Fn function will be executed. + // If Acquire() returns false, the task.Fn function won't be executed this time, but Acquire() will be called very soon. + // It works like: + // for { + // blocked := rateLimiter.Acquire() + // if !blocked { + // task.Fn() + // } + // } + // Acquire() should block the caller until execution is allowed. + Acquire() bool + + // Stop is used to disable the rate limiter. + // It can be implemented as a noop if not needed. + Stop() } -// stableRateLimiter uses the token bucket algorithm. +// A StableRateLimiter uses the token bucket algorithm. // the bucket is refilled according to the refill period, no burst is allowed. -type stableRateLimiter struct { +type StableRateLimiter struct { threshold int64 currentThreshold int64 - refillPeroid time.Duration + refillPeriod time.Duration broadcastChannel chan bool quitChannel chan bool } -func newStableRateLimiter(threshold int64, refillPeroid time.Duration) (rateLimiter *stableRateLimiter) { - rateLimiter = &stableRateLimiter{ +// NewStableRateLimiter returns a StableRateLimiter. +func NewStableRateLimiter(threshold int64, refillPeriod time.Duration) (rateLimiter *StableRateLimiter) { + rateLimiter = &StableRateLimiter{ threshold: threshold, currentThreshold: threshold, - refillPeroid: refillPeroid, + refillPeriod: refillPeriod, broadcastChannel: make(chan bool), } return rateLimiter } -func (limiter *stableRateLimiter) start() { +// Start to refill the bucket periodically. +func (limiter *StableRateLimiter) Start() { limiter.quitChannel = make(chan bool) quitChannel := limiter.quitChannel go func() { @@ -46,7 +65,7 @@ func (limiter *stableRateLimiter) start() { return default: atomic.StoreInt64(&limiter.currentThreshold, limiter.threshold) - time.Sleep(limiter.refillPeroid) + time.Sleep(limiter.refillPeriod) close(limiter.broadcastChannel) limiter.broadcastChannel = make(chan bool) } @@ -54,7 +73,8 @@ func (limiter *stableRateLimiter) start() { }() } -func (limiter *stableRateLimiter) acquire() (blocked bool) { +// Acquire a token from the bucket, returns true if the bucket is exhausted. +func (limiter *StableRateLimiter) Acquire() (blocked bool) { permit := atomic.AddInt64(&limiter.currentThreshold, -1) if permit < 0 { blocked = true @@ -66,21 +86,22 @@ func (limiter *stableRateLimiter) acquire() (blocked bool) { return blocked } -func (limiter *stableRateLimiter) stop() { +// Stop the rate limiter. +func (limiter *StableRateLimiter) Stop() { close(limiter.quitChannel) } // ErrParsingRampUpRate is the error returned if the format of rampUpRate is invalid. var ErrParsingRampUpRate = errors.New("ratelimiter: invalid format of rampUpRate, try \"1\" or \"1/1s\"") -// rampUpRateLimiter uses the token bucket algorithm. +// A RampUpRateLimiter uses the token bucket algorithm. // the threshold is updated according to the warm up rate. // the bucket is refilled according to the refill period, no burst is allowed. -type rampUpRateLimiter struct { +type RampUpRateLimiter struct { maxThreshold int64 nextThreshold int64 currentThreshold int64 - refillPeroid time.Duration + refillPeriod time.Duration rampUpRate string rampUpStep int64 rampUpPeroid time.Duration @@ -89,13 +110,15 @@ type rampUpRateLimiter struct { quitChannel chan bool } -func newRampUpRateLimiter(maxThreshold int64, rampUpRate string, refillPeroid time.Duration) (rateLimiter *rampUpRateLimiter, err error) { - rateLimiter = &rampUpRateLimiter{ +// NewRampUpRateLimiter returns a RampUpRateLimiter. +// Valid formats of rampUpRate are "1", "1/1s". +func NewRampUpRateLimiter(maxThreshold int64, rampUpRate string, refillPeriod time.Duration) (rateLimiter *RampUpRateLimiter, err error) { + rateLimiter = &RampUpRateLimiter{ maxThreshold: maxThreshold, nextThreshold: 0, currentThreshold: 0, rampUpRate: rampUpRate, - refillPeroid: refillPeroid, + refillPeriod: refillPeriod, broadcastChannel: make(chan bool), } rateLimiter.rampUpStep, rateLimiter.rampUpPeroid, err = rateLimiter.parseRampUpRate(rateLimiter.rampUpRate) @@ -105,7 +128,7 @@ func newRampUpRateLimiter(maxThreshold int64, rampUpRate string, refillPeroid ti return rateLimiter, nil } -func (limiter *rampUpRateLimiter) parseRampUpRate(rampUpRate string) (rampUpStep int64, rampUpPeroid time.Duration, err error) { +func (limiter *RampUpRateLimiter) parseRampUpRate(rampUpRate string) (rampUpStep int64, rampUpPeroid time.Duration, err error) { if strings.Contains(rampUpRate, "/") { tmp := strings.Split(rampUpRate, "/") if len(tmp) != 2 { @@ -130,7 +153,8 @@ func (limiter *rampUpRateLimiter) parseRampUpRate(rampUpRate string) (rampUpStep return rampUpStep, rampUpPeroid, nil } -func (limiter *rampUpRateLimiter) start() { +// Start to refill the bucket periodically. +func (limiter *RampUpRateLimiter) Start() { limiter.quitChannel = make(chan bool) quitChannel := limiter.quitChannel // bucket updater @@ -141,7 +165,7 @@ func (limiter *rampUpRateLimiter) start() { return default: atomic.StoreInt64(&limiter.currentThreshold, limiter.nextThreshold) - time.Sleep(limiter.refillPeroid) + time.Sleep(limiter.refillPeriod) close(limiter.broadcastChannel) limiter.broadcastChannel = make(chan bool) } @@ -169,7 +193,8 @@ func (limiter *rampUpRateLimiter) start() { }() } -func (limiter *rampUpRateLimiter) acquire() (blocked bool) { +// Acquire a token from the bucket, returns true if the bucket is exhausted. +func (limiter *RampUpRateLimiter) Acquire() (blocked bool) { permit := atomic.AddInt64(&limiter.currentThreshold, -1) if permit < 0 { blocked = true @@ -181,7 +206,8 @@ func (limiter *rampUpRateLimiter) acquire() (blocked bool) { return blocked } -func (limiter *rampUpRateLimiter) stop() { +// Stop the rate limiter. +func (limiter *RampUpRateLimiter) Stop() { limiter.nextThreshold = 0 close(limiter.quitChannel) } diff --git a/ratelimiter_test.go b/ratelimiter_test.go index e95315d..6751d2c 100644 --- a/ratelimiter_test.go +++ b/ratelimiter_test.go @@ -6,31 +6,31 @@ import ( ) func TestStableRateLimiter(t *testing.T) { - rateLimiter := newStableRateLimiter(1, 10*time.Millisecond) - rateLimiter.start() - blocked := rateLimiter.acquire() + rateLimiter := NewStableRateLimiter(1, 10*time.Millisecond) + rateLimiter.Start() + blocked := rateLimiter.Acquire() if blocked { t.Error("Unexpected blocked by rate limiter") } - blocked = rateLimiter.acquire() + blocked = rateLimiter.Acquire() if !blocked { t.Error("Should be blocked") } - rateLimiter.stop() + rateLimiter.Stop() } func TestRampUpRateLimiter(t *testing.T) { - rateLimiter, _ := newRampUpRateLimiter(100, "10/200ms", 100*time.Millisecond) - rateLimiter.start() + rateLimiter, _ := NewRampUpRateLimiter(100, "10/200ms", 100*time.Millisecond) + rateLimiter.Start() time.Sleep(110 * time.Millisecond) for i := 0; i < 10; i++ { - blocked := rateLimiter.acquire() + blocked := rateLimiter.Acquire() if blocked { t.Error("Unexpected blocked by rate limiter") } } - blocked := rateLimiter.acquire() + blocked := rateLimiter.Acquire() if !blocked { t.Error("Should be blocked") } @@ -39,39 +39,39 @@ func TestRampUpRateLimiter(t *testing.T) { // now, the threshold is 20 for i := 0; i < 20; i++ { - blocked := rateLimiter.acquire() + blocked := rateLimiter.Acquire() if blocked { t.Error("Unexpected blocked by rate limiter") } } - blocked = rateLimiter.acquire() + blocked = rateLimiter.Acquire() if !blocked { t.Error("Should be blocked") } - rateLimiter.stop() + rateLimiter.Stop() } func TestParseRampUpRate(t *testing.T) { - rateLimiter := &rampUpRateLimiter{} - rampUpStep, rampUpPeroid, _ := rateLimiter.parseRampUpRate("100") + rateLimiter := &RampUpRateLimiter{} + rampUpStep, rampUpPeriod, _ := rateLimiter.parseRampUpRate("100") if rampUpStep != 100 { t.Error("Wrong rampUpStep, expected: 100, was:", rampUpStep) } - if rampUpPeroid != time.Second { - t.Error("Wrong rampUpPeroid, expected: 1s, was:", rampUpPeroid) + if rampUpPeriod != time.Second { + t.Error("Wrong rampUpPeriod, expected: 1s, was:", rampUpPeriod) } - rampUpStep, rampUpPeroid, _ = rateLimiter.parseRampUpRate("200/10s") + rampUpStep, rampUpPeriod, _ = rateLimiter.parseRampUpRate("200/10s") if rampUpStep != 200 { t.Error("Wrong rampUpStep, expected: 200, was:", rampUpStep) } - if rampUpPeroid != 10*time.Second { - t.Error("Wrong rampUpPeroid, expected: 10s, was:", rampUpPeroid) + if rampUpPeriod != 10*time.Second { + t.Error("Wrong rampUpPeriod, expected: 10s, was:", rampUpPeriod) } } func TestParseInvalidRampUpRate(t *testing.T) { - rateLimiter := &rampUpRateLimiter{} + rateLimiter := &RampUpRateLimiter{} _, _, err := rateLimiter.parseRampUpRate("A/1m") if err == nil || err != ErrParsingRampUpRate { diff --git a/runner.go b/runner.go index d8e2f1a..8f21b71 100644 --- a/runner.go +++ b/runner.go @@ -41,14 +41,12 @@ type runner struct { client client nodeID string hatchType string - rateLimiter rateLimiter + rateLimiter RateLimiter rateLimitEnabled bool stats *requestStats - // cache of current time in second - now int64 } -func newRunner(tasks []*Task, rateLimiter rateLimiter, hatchType string) (r *runner) { +func newRunner(tasks []*Task, rateLimiter RateLimiter, hatchType string) (r *runner) { r = &runner{ tasks: tasks, hatchType: hatchType, @@ -61,10 +59,6 @@ func newRunner(tasks []*Task, rateLimiter rateLimiter, hatchType string) (r *run r.rateLimiter = rateLimiter } - if hatchType != "asap" && hatchType != "smooth" { - log.Fatalf("Wrong hatch-type, expected asap or smooth, was %s\n", hatchType) - } - r.stats = newRequestStats() return r @@ -122,7 +116,7 @@ func (r *runner) spawnGoRoutines(spawnCount int, quit chan bool) { return default: if r.rateLimitEnabled { - blocked := r.rateLimiter.acquire() + blocked := r.rateLimiter.Acquire() if !blocked { r.safeRun(fn) } @@ -170,7 +164,7 @@ func (r *runner) stop() { // those goroutines will exit when r.safeRun returns close(r.stopChannel) if r.rateLimitEnabled { - r.rateLimiter.stop() + r.rateLimiter.Stop() } } @@ -202,7 +196,7 @@ func (r *runner) onHatchMessage(msg *message) { Events.Publish("boomer:hatch", workers, hatchRate) if r.rateLimitEnabled { - r.rateLimiter.start() + r.rateLimiter.Start() } r.startHatching(workers, hatchRate) } @@ -210,18 +204,14 @@ func (r *runner) onHatchMessage(msg *message) { // Runner acts as a state machine, and runs in one goroutine without any lock. func (r *runner) onMessage(msg *message) { - if msg.Type == "quit" { - log.Println("Got quit message from master, shutting down...") - r.state = stateQuitting - Events.Publish("boomer:quit") - os.Exit(0) - } - switch r.state { case stateInit: - if msg.Type == "hatch" { + switch msg.Type { + case "hatch": r.state = stateHatching r.onHatchMessage(msg) + case "quit": + Events.Publish("boomer:quit") } case stateHatching: fallthrough @@ -237,11 +227,20 @@ func (r *runner) onMessage(msg *message) { log.Println("Recv stop message from master, all the goroutines are stopped") r.client.sendChannel() <- newMessage("client_stopped", nil, r.nodeID) r.client.sendChannel() <- newMessage("client_ready", nil, r.nodeID) + case "quit": + r.stop() + log.Println("Recv quit message from master, all the goroutines are stopped") + Events.Publish("boomer:quit") + r.state = stateInit } case stateStopped: - if msg.Type == "hatch" { + switch msg.Type { + case "hatch": r.state = stateHatching r.onHatchMessage(msg) + case "quit": + Events.Publish("boomer:quit") + r.state = stateInit } } } @@ -288,17 +287,5 @@ func (r *runner) getReady() { } }() - go func() { - var ticker = time.NewTicker(time.Second) - for { - select { - case <-ticker.C: - r.now = time.Now().Unix() - case <-r.shutdownSignal: - return - } - } - }() - Events.Subscribe("boomer:quit", r.onQuiting) } diff --git a/runner_test.go b/runner_test.go index 0eea0a1..cbe3178 100644 --- a/runner_test.go +++ b/runner_test.go @@ -7,11 +7,10 @@ import ( ) func TestSafeRun(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") - defaultRunner.safeRun(func() { + runner := newRunner(nil, nil, "asap") + runner.safeRun(func() { panic("Runner will catch this panic") }) - defaultRunner = nil } func TestSpawnGoRoutines(t *testing.T) { @@ -30,7 +29,7 @@ func TestSpawnGoRoutines(t *testing.T) { Name: "TaskB", } tasks := []*Task{taskA, taskB} - rateLimiter := newStableRateLimiter(100, time.Second) + rateLimiter := NewStableRateLimiter(100, time.Second) runner := newRunner(tasks, rateLimiter, "asap") defer runner.close() @@ -327,7 +326,7 @@ func TestGetReady(t *testing.T) { defer server.close() server.start() - rateLimiter := newStableRateLimiter(100, time.Second) + rateLimiter := NewStableRateLimiter(100, time.Second) r := newRunner(nil, rateLimiter, "asap") r.masterHost = masterHost r.masterPort = masterPort diff --git a/stats.go b/stats.go index c18ab59..7fe0580 100644 --- a/stats.go +++ b/stats.go @@ -197,9 +197,7 @@ func (s *statsEntry) log(responseTime int64, contentLength int64) { } func (s *statsEntry) logTimeOfRequest() { - // 'now' is updated by another goroutine - // make a copy to avoid race condition - key := defaultRunner.now + key := time.Now().Unix() _, ok := s.numReqsPerSec[key] if !ok { s.numReqsPerSec[key] = 1 diff --git a/stats_test.go b/stats_test.go index 47b9c17..da851be 100644 --- a/stats_test.go +++ b/stats_test.go @@ -6,7 +6,6 @@ import ( ) func TestLogRequest(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() newStats.logRequest("http", "success", 2, 30) newStats.logRequest("http", "success", 3, 40) @@ -46,21 +45,16 @@ func TestLogRequest(t *testing.T) { if newStats.total.totalContentLength != 130 { t.Error("newStats.total.totalContentLength is wrong, expected: 130, got:", newStats.total.totalContentLength) } - - defaultRunner = nil } func BenchmarkLogRequest(b *testing.B) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() for i := 0; i < b.N; i++ { newStats.logRequest("http", "success", 2, 30) } - defaultRunner = nil } func TestRoundedResponseTime(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() newStats.logRequest("http", "success", 147, 1) newStats.logRequest("http", "success", 3432, 1) @@ -83,12 +77,9 @@ func TestRoundedResponseTime(t *testing.T) { if val, ok := responseTimes[59000]; !ok || val != 1 { t.Error("Rounded response time should be", 59000) } - - defaultRunner = nil } func TestLogError(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() newStats.logError("http", "failure", "500 error") newStats.logError("http", "failure", "400 error") @@ -121,22 +112,18 @@ func TestLogError(t *testing.T) { t.Error("Error occurences is wrong, expected: 2, got:", err400.occurences) } - defaultRunner = nil } func BenchmarkLogError(b *testing.B) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() for i := 0; i < b.N; i++ { // LogError use md5 to calculate hash keys, it may slow down the only goroutine, // which consumes both requestSuccessChannel and requestFailureChannel. newStats.logError("http", "failure", "500 error") } - defaultRunner = nil } func TestClearAll(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() newStats.logRequest("http", "success", 1, 20) newStats.clearAll() @@ -144,11 +131,9 @@ func TestClearAll(t *testing.T) { if newStats.total.numRequests != 0 { t.Error("After clearAll(), newStats.total.numRequests is wrong, expected: 0, got:", newStats.total.numRequests) } - defaultRunner = nil } func TestClearAllByChannel(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() newStats.start() defer newStats.close() @@ -158,11 +143,9 @@ func TestClearAllByChannel(t *testing.T) { if newStats.total.numRequests != 0 { t.Error("After clearAll(), newStats.total.numRequests is wrong, expected: 0, got:", newStats.total.numRequests) } - defaultRunner = nil } func TestSerializeStats(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() newStats.logRequest("http", "success", 1, 20) @@ -185,12 +168,9 @@ func TestSerializeStats(t *testing.T) { if first["num_failures"].(int64) != int64(0) { t.Error("The num_failures is wrong, expected:", 0, "got:", first["num_failures"].(int64)) } - - defaultRunner = nil } func TestSerializeErrors(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() newStats.logError("http", "failure", "500 error") newStats.logError("http", "failure", "400 error") @@ -214,11 +194,9 @@ func TestSerializeErrors(t *testing.T) { } } } - defaultRunner = nil } func TestCollectReportData(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() newStats.logRequest("http", "success", 2, 30) newStats.logError("http", "failure", "500 error") @@ -233,11 +211,9 @@ func TestCollectReportData(t *testing.T) { if _, ok := result["errors"]; !ok { t.Error("Key stats not found") } - defaultRunner = nil } func TestStatsStart(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() newStats.start() defer newStats.close() @@ -266,6 +242,4 @@ func TestStatsStart(t *testing.T) { } } end: - - defaultRunner = nil } diff --git a/utils.go b/utils.go index 7c0a8bf..ffe567b 100644 --- a/utils.go +++ b/utils.go @@ -4,8 +4,10 @@ import ( "crypto/md5" "fmt" "io" + "log" "math" "os" + "runtime/pprof" "strings" "time" @@ -47,3 +49,44 @@ func getNodeID() (nodeID string) { func Now() int64 { return time.Now().UnixNano() / int64(time.Millisecond) } + +// StartMemoryProfile starts memory profiling and save the results in file. +func StartMemoryProfile(file string, duration time.Duration) { + f, err := os.Create(file) + if err != nil { + log.Fatal(err) + } + + log.Println("Start memory profiling for", duration) + time.AfterFunc(duration, func() { + err = pprof.WriteHeapProfile(f) + if err != nil { + log.Println(err) + return + } + f.Close() + log.Println("Stop memory profiling after", duration) + }) +} + +// StartCPUProfile starts cpu profiling and save the results in file. +func StartCPUProfile(file string, duration time.Duration) { + f, err := os.Create(file) + if err != nil { + log.Fatal(err) + } + + log.Println("Start cpu profiling for", duration) + err = pprof.StartCPUProfile(f) + if err != nil { + log.Println(err) + f.Close() + return + } + + time.AfterFunc(duration, func() { + pprof.StopCPUProfile() + f.Close() + log.Println("Stop CPU profiling after", duration) + }) +} diff --git a/utils_test.go b/utils_test.go index 131081e..17fb998 100644 --- a/utils_test.go +++ b/utils_test.go @@ -4,6 +4,7 @@ import ( "os" "regexp" "testing" + "time" ) func TestRound(t *testing.T) { @@ -59,3 +60,29 @@ func TestNow(t *testing.T) { t.Error("Invalid format of timestamp in milliseconds") } } + +func TestStartMemoryProfile(t *testing.T) { + if _, err := os.Stat("mem.pprof"); os.IsExist(err) { + os.Remove("mem.pprof") + } + StartMemoryProfile("mem.pprof", 2*time.Second) + time.Sleep(2100 * time.Millisecond) + if _, err := os.Stat("mem.pprof"); os.IsNotExist(err) { + t.Error("File mem.pprof is not generated") + } else { + os.Remove("mem.pprof") + } +} + +func TestStartCPUProfile(t *testing.T) { + if _, err := os.Stat("cpu.pprof"); os.IsExist(err) { + os.Remove("cpu.pprof") + } + StartCPUProfile("cpu.pprof", 2*time.Second) + time.Sleep(2100 * time.Millisecond) + if _, err := os.Stat("cpu.pprof"); os.IsNotExist(err) { + t.Error("File cpu.pprof is not generated") + } else { + os.Remove("cpu.pprof") + } +}