diff --git a/boomer.go b/boomer.go index fec9746..896ff05 100644 --- a/boomer.go +++ b/boomer.go @@ -3,45 +3,102 @@ package boomer import ( "flag" "log" - "math" "os" "os/signal" - "runtime/pprof" "strings" - "sync" - "sync/atomic" "syscall" "time" + + "github.com/asaskevich/EventBus" ) -var masterHost string -var masterPort int +// Events is the global event bus instance. +var Events = EventBus.New() -var maxRPS int64 -var requestIncreaseRate string -var hatchType string -var runTasks string -var memoryProfile string -var memoryProfileDuration time.Duration -var cpuProfile string -var cpuProfileDuration time.Duration +var defaultBoomer *Boomer -var defaultRunner *runner +// A Boomer is used to run tasks. +type Boomer struct { + masterHost string + masterPort int -var initiated uint32 -var initMutex = sync.Mutex{} + hatchType string + rateLimiter RateLimiter + runner *runner +} -// Init boomer -func initBoomer() { - if atomic.LoadUint32(&initiated) == 1 { - panic("Don't call boomer.Run() more than once.") +// NewBoomer returns a new Boomer. +func NewBoomer(masterHost string, masterPort int) *Boomer { + return &Boomer{ + masterHost: masterHost, + masterPort: masterPort, + hatchType: "asap", } +} - // TODO: to be removed - initLegacyEventHandlers() +// SetRateLimiter allows user to use their own rate limiter. +// It must be called before the test is started. +func (b *Boomer) SetRateLimiter(rateLimiter RateLimiter) { + b.rateLimiter = rateLimiter +} + +// SetHatchType only accepts "asap" or "smooth". +// "asap" means spawning goroutines as soon as possible when the test is started. +// "smooth" means a constant pace. +func (b *Boomer) SetHatchType(hatchType string) { + if hatchType != "asap" && hatchType != "smooth" { + log.Printf("Wrong hatch-type, expected asap or smooth, was %s\n", hatchType) + return + } + b.hatchType = hatchType +} + +func (b *Boomer) setRunner(runner *runner) { + b.runner = runner +} + +// Run accepts a slice of Task and connects to the locust master. +func (b *Boomer) Run(tasks ...*Task) { + b.runner = newRunner(tasks, b.rateLimiter, b.hatchType) + b.runner.masterHost = b.masterHost + b.runner.masterPort = b.masterPort + b.runner.getReady() +} + +// RecordSuccess reports a success +func (b *Boomer) RecordSuccess(requestType, name string, responseTime int64, responseLength int64) { + b.runner.stats.requestSuccessChannel <- &requestSuccess{ + requestType: requestType, + name: name, + responseTime: responseTime, + responseLength: responseLength, + } +} - // done - atomic.StoreUint32(&initiated, 1) +// RecordFailure reports a failure. +func (b *Boomer) RecordFailure(requestType, name string, responseTime int64, exception string) { + b.runner.stats.requestFailureChannel <- &requestFailure{ + requestType: requestType, + name: name, + responseTime: responseTime, + error: exception, + } +} + +// Quit will send a quit message to the master. +func (b *Boomer) Quit() { + Events.Publish("boomer:quit") + var ticker = time.NewTicker(3 * time.Second) + for { + // wait for quit message is sent to master + select { + case <-b.runner.client.disconnectedChannel(): + return + case <-ticker.C: + log.Println("Timeout waiting for sending quit message to master, boomer will quit any way.") + return + } + } } // Run tasks without connecting to the master. @@ -61,26 +118,8 @@ func runTasksForTest(tasks ...*Task) { } } -func createRateLimiter(maxRPS int64, requestIncreaseRate string) (rateLimiter rateLimiter, err error) { - if requestIncreaseRate != "-1" { - if maxRPS > 0 { - log.Println("The max RPS that boomer may generate is limited to", maxRPS, "with a increase rate", requestIncreaseRate) - rateLimiter, err = newRampUpRateLimiter(maxRPS, requestIncreaseRate, time.Second) - } else { - log.Println("The max RPS that boomer may generate is limited by a increase rate", requestIncreaseRate) - rateLimiter, err = newRampUpRateLimiter(math.MaxInt64, requestIncreaseRate, time.Second) - } - } else { - if maxRPS > 0 { - log.Println("The max RPS that boomer may generate is limited to", maxRPS) - rateLimiter = newStableRateLimiter(maxRPS, time.Second) - } - } - return rateLimiter, err -} - -// Run accepts a slice of Task and connects -// to a locust master. +// Run accepts a slice of Task and connects to a locust master. +// It's a convenience function to use the defaultBoomer. func Run(tasks ...*Task) { if !flag.Parsed() { flag.Parse() @@ -91,85 +130,43 @@ func Run(tasks ...*Task) { return } - // init boomer - initMutex.Lock() - initBoomer() - initMutex.Unlock() - - rateLimiter, err := createRateLimiter(maxRPS, requestIncreaseRate) - if err != nil { - log.Fatalf("Failed to create rate limiter, %v\n", err) - } - - defaultRunner = newRunner(tasks, rateLimiter, hatchType) - defaultRunner.masterHost = masterHost - defaultRunner.masterPort = masterPort - defaultRunner.getReady() + defaultBoomer = NewBoomer(masterHost, masterPort) + initLegacyEventHandlers() if memoryProfile != "" { - startMemoryProfile(memoryProfile, memoryProfileDuration) + StartMemoryProfile(memoryProfile, memoryProfileDuration) } if cpuProfile != "" { - startCPUProfile(cpuProfile, cpuProfileDuration) + StartCPUProfile(cpuProfile, cpuProfileDuration) } + rateLimiter, err := createRateLimiter(maxRPS, requestIncreaseRate) + if err != nil { + log.Fatalf("%v\n", err) + } + defaultBoomer.SetRateLimiter(rateLimiter) + defaultBoomer.hatchType = hatchType + + defaultBoomer.Run(tasks...) + c := make(chan os.Signal) signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) <-c - Events.Publish("boomer:quit") + defaultBoomer.Quit() - // wait for quit message is sent to master - <-defaultRunner.client.disconnectedChannel() log.Println("shut down") } -func startMemoryProfile(file string, duration time.Duration) { - f, err := os.Create(file) - if err != nil { - log.Fatal(err) - } - time.AfterFunc(duration, func() { - err = pprof.WriteHeapProfile(f) - if err != nil { - log.Println(err) - return - } - f.Close() - log.Println("Stop memory profiling after", duration) - }) -} - -func startCPUProfile(file string, duration time.Duration) { - f, err := os.Create(file) - if err != nil { - log.Fatal(err) - } - - err = pprof.StartCPUProfile(f) - if err != nil { - log.Println(err) - f.Close() - return - } - - time.AfterFunc(duration, func() { - pprof.StopCPUProfile() - f.Close() - log.Println("Stop CPU profiling after", duration) - }) +// RecordSuccess reports a success. +// It's a convenience function to use the defaultBoomer. +func RecordSuccess(requestType, name string, responseTime int64, responseLength int64) { + defaultBoomer.RecordSuccess(requestType, name, responseTime, responseLength) } -func init() { - flag.Int64Var(&maxRPS, "max-rps", 0, "Max RPS that boomer can generate, disabled by default.") - flag.StringVar(&requestIncreaseRate, "request-increase-rate", "-1", "Request increase rate, disabled by default.") - flag.StringVar(&hatchType, "hatch-type", "asap", "How to create goroutines according to hatch rate, 'asap' will do it as soon as possible while 'smooth' means a constant pace.") - flag.StringVar(&runTasks, "run-tasks", "", "Run tasks without connecting to the master, multiply tasks is separated by comma. Usually, it's for debug purpose.") - flag.StringVar(&masterHost, "master-host", "127.0.0.1", "Host or IP address of locust master for distributed load testing.") - flag.IntVar(&masterPort, "master-port", 5557, "The port to connect to that is used by the locust master for distributed load testing.") - flag.StringVar(&memoryProfile, "mem-profile", "", "Enable memory profiling.") - flag.DurationVar(&memoryProfileDuration, "mem-profile-duration", 30*time.Second, "Memory profile duration.") - flag.StringVar(&cpuProfile, "cpu-profile", "", "Enable CPU profiling.") - flag.DurationVar(&cpuProfileDuration, "cpu-profile-duration", 30*time.Second, "CPU profile duration.") +// RecordFailure reports a failure. +// It's a convenience function to use the defaultBoomer. +func RecordFailure(requestType, name string, responseTime int64, exception string) { + defaultBoomer.RecordFailure(requestType, name, responseTime, exception) } diff --git a/boomer_test.go b/boomer_test.go index c4e084c..e773422 100644 --- a/boomer_test.go +++ b/boomer_test.go @@ -2,25 +2,10 @@ package boomer import ( "math" - "os" "testing" "time" ) -func TestInitBoomer(t *testing.T) { - initBoomer() - defer Events.Unsubscribe("request_success", legacySuccessHandler) - defer Events.Unsubscribe("request_failure", legacyFailureHandler) - - defer func() { - err := recover() - if err == nil { - t.Error("It should panic if initBoomer is called more than once.") - } - }() - initBoomer() -} - func TestRunTasksForTest(t *testing.T) { count := 0 taskA := &Task{ @@ -36,6 +21,7 @@ func TestRunTasksForTest(t *testing.T) { }, } runTasks = "increaseCount,foobar" + runTasksForTest(taskA, taskWithoutName) if count != 1 { @@ -45,7 +31,7 @@ func TestRunTasksForTest(t *testing.T) { func TestCreateRatelimiter(t *testing.T) { rateLimiter, _ := createRateLimiter(100, "-1") - if stableRateLimiter, ok := rateLimiter.(*stableRateLimiter); !ok { + if stableRateLimiter, ok := rateLimiter.(*StableRateLimiter); !ok { t.Error("Expected stableRateLimiter") } else { if stableRateLimiter.threshold != 100 { @@ -54,7 +40,7 @@ func TestCreateRatelimiter(t *testing.T) { } rateLimiter, _ = createRateLimiter(0, "1") - if rampUpRateLimiter, ok := rateLimiter.(*rampUpRateLimiter); !ok { + if rampUpRateLimiter, ok := rateLimiter.(*RampUpRateLimiter); !ok { t.Error("Expected rampUpRateLimiter") } else { if rampUpRateLimiter.maxThreshold != math.MaxInt64 { @@ -66,7 +52,7 @@ func TestCreateRatelimiter(t *testing.T) { } rateLimiter, _ = createRateLimiter(10, "2/2s") - if rampUpRateLimiter, ok := rateLimiter.(*rampUpRateLimiter); !ok { + if rampUpRateLimiter, ok := rateLimiter.(*RampUpRateLimiter); !ok { t.Error("Expected rampUpRateLimiter") } else { if rampUpRateLimiter.maxThreshold != 10 { @@ -84,28 +70,35 @@ func TestCreateRatelimiter(t *testing.T) { } } -func TestStartMemoryProfile(t *testing.T) { - if _, err := os.Stat("mem.pprof"); os.IsExist(err) { - os.Remove("mem.pprof") +func TestRecordSuccess(t *testing.T) { + defaultBoomer = NewBoomer("127.0.0.1", 5557) + defaultBoomer.runner = newRunner(nil, nil, "asap") + RecordSuccess("http", "foo", int64(1), int64(10)) + + requestSuccessMsg := <-defaultBoomer.runner.stats.requestSuccessChannel + if requestSuccessMsg.requestType != "http" { + t.Error("Expected: http, got:", requestSuccessMsg.requestType) } - startMemoryProfile("mem.pprof", 2*time.Second) - time.Sleep(2100 * time.Millisecond) - if _, err := os.Stat("mem.pprof"); os.IsNotExist(err) { - t.Error("File mem.pprof is not generated") - } else { - os.Remove("mem.pprof") + if requestSuccessMsg.responseTime != int64(1) { + t.Error("Expected: 1, got:", requestSuccessMsg.responseTime) } + defaultBoomer = nil } -func TestStartCPUProfile(t *testing.T) { - if _, err := os.Stat("cpu.pprof"); os.IsExist(err) { - os.Remove("cpu.pprof") +func TestRecordFailure(t *testing.T) { + defaultBoomer = NewBoomer("127.0.0.1", 5557) + defaultBoomer.runner = newRunner(nil, nil, "asap") + RecordFailure("udp", "bar", int64(2), "udp error") + + requestFailureMsg := <-defaultBoomer.runner.stats.requestFailureChannel + if requestFailureMsg.requestType != "udp" { + t.Error("Expected: udp, got:", requestFailureMsg.requestType) } - startCPUProfile("cpu.pprof", 2*time.Second) - time.Sleep(2100 * time.Millisecond) - if _, err := os.Stat("cpu.pprof"); os.IsNotExist(err) { - t.Error("File cpu.pprof is not generated") - } else { - os.Remove("cpu.pprof") + if requestFailureMsg.responseTime != int64(2) { + t.Error("Expected: 2, got:", requestFailureMsg.responseTime) + } + if requestFailureMsg.error != "udp error" { + t.Error("Expected: udp error, got:", requestFailureMsg.error) } + defaultBoomer = nil } diff --git a/events.go b/events.go deleted file mode 100644 index 12dda0e..0000000 --- a/events.go +++ /dev/null @@ -1,28 +0,0 @@ -package boomer - -import ( - "github.com/asaskevich/EventBus" -) - -// Events is core event bus instance of boomer -var Events = EventBus.New() - -// RecordSuccess reports a success -func RecordSuccess(requestType, name string, responseTime int64, responseLength int64) { - defaultRunner.stats.requestSuccessChannel <- &requestSuccess{ - requestType: requestType, - name: name, - responseTime: responseTime, - responseLength: responseLength, - } -} - -// RecordFailure reports a failure -func RecordFailure(requestType, name string, responseTime int64, exception string) { - defaultRunner.stats.requestFailureChannel <- &requestFailure{ - requestType: requestType, - name: name, - responseTime: responseTime, - error: exception, - } -} diff --git a/events_test.go b/events_test.go deleted file mode 100644 index cf90918..0000000 --- a/events_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package boomer - -import "testing" - -func TestRecordSuccess(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") - RecordSuccess("http", "foo", int64(1), int64(10)) - - requestSuccessMsg := <-defaultRunner.stats.requestSuccessChannel - if requestSuccessMsg.requestType != "http" { - t.Error("Expected: http, got:", requestSuccessMsg.requestType) - } - if requestSuccessMsg.responseTime != int64(1) { - t.Error("Expected: 1, got:", requestSuccessMsg.responseTime) - } - defaultRunner = nil -} - -func TestRecordFailure(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") - RecordFailure("udp", "bar", int64(2), "udp error") - - requestFailureMsg := <-defaultRunner.stats.requestFailureChannel - if requestFailureMsg.requestType != "udp" { - t.Error("Expected: udp, got:", requestFailureMsg.requestType) - } - if requestFailureMsg.responseTime != int64(2) { - t.Error("Expected: 2, got:", requestFailureMsg.responseTime) - } - if requestFailureMsg.error != "udp error" { - t.Error("Expected: udp error, got:", requestFailureMsg.error) - } - defaultRunner = nil -} diff --git a/examples/main.go b/examples/main.go index cea265e..9518c1e 100644 --- a/examples/main.go +++ b/examples/main.go @@ -2,33 +2,52 @@ package main import ( "log" + "os" + "os/signal" + "sync" + "syscall" "time" "github.com/myzhan/boomer" ) func foo() { - start := boomer.Now() time.Sleep(100 * time.Millisecond) elapsed := boomer.Now() - start // Report your test result as a success, if you write it in python, it will looks like this // events.request_success.fire(request_type="http", name="foo", response_time=100, response_length=10) - boomer.RecordSuccess("http", "foo", elapsed, int64(10)) + globalBoomer.RecordSuccess("http", "foo", elapsed, int64(10)) } func bar() { - start := boomer.Now() time.Sleep(100 * time.Millisecond) elapsed := boomer.Now() - start // Report your test result as a failure, if you write it in python, it will looks like this // events.request_failure.fire(request_type="udp", name="bar", response_time=100, exception=Exception("udp error")) - boomer.RecordFailure("udp", "bar", elapsed, "udp error") + globalBoomer.RecordFailure("udp", "bar", elapsed, "udp error") +} + +func waitForQuit() { + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + c := make(chan os.Signal) + signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + <-c + globalBoomer.Quit() + wg.Done() + }() + + wg.Wait() } +var globalBoomer = boomer.NewBoomer("127.0.0.1", 5557) + func main() { log.SetFlags(log.LstdFlags | log.Lshortfile) @@ -40,9 +59,12 @@ func main() { task2 := &boomer.Task{ Name: "bar", - Weight: 20, + Weight: 30, Fn: bar, } - boomer.Run(task1, task2) + globalBoomer.Run(task1, task2) + + waitForQuit() + log.Println("shut down") } diff --git a/examples/ratelimit/maxrps/main.go b/examples/ratelimit/maxrps/main.go new file mode 100644 index 0000000..fea96d5 --- /dev/null +++ b/examples/ratelimit/maxrps/main.go @@ -0,0 +1,58 @@ +package main + +import ( + "log" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/myzhan/boomer" +) + +func foo() { + start := boomer.Now() + time.Sleep(100 * time.Millisecond) + elapsed := boomer.Now() - start + + // Report your test result as a success, if you write it in python, it will looks like this + // events.request_success.fire(request_type="http", name="foo", response_time=100, response_length=10) + globalBoomer.RecordSuccess("http", "foo", elapsed, int64(10)) +} + +func waitForQuit() { + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + c := make(chan os.Signal) + signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + <-c + globalBoomer.Quit() + wg.Done() + }() + + wg.Wait() +} + +var globalBoomer = boomer.NewBoomer("127.0.0.1", 5557) + +func main() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + task1 := &boomer.Task{ + Name: "foo", + Weight: 10, + Fn: foo, + } + + ratelimiter := boomer.NewStableRateLimiter(100, time.Second) + log.Println("the max rps is limited to 100/s.") + globalBoomer.SetRateLimiter(ratelimiter) + + globalBoomer.Run(task1) + + waitForQuit() + log.Println("shut down") +} diff --git a/examples/ratelimit/rampup/main.go b/examples/ratelimit/rampup/main.go new file mode 100644 index 0000000..dec10d8 --- /dev/null +++ b/examples/ratelimit/rampup/main.go @@ -0,0 +1,58 @@ +package main + +import ( + "log" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/myzhan/boomer" +) + +func foo() { + start := boomer.Now() + time.Sleep(100 * time.Millisecond) + elapsed := boomer.Now() - start + + // Report your test result as a success, if you write it in python, it will looks like this + // events.request_success.fire(request_type="http", name="foo", response_time=100, response_length=10) + globalBoomer.RecordSuccess("http", "foo", elapsed, int64(10)) +} + +func waitForQuit() { + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + c := make(chan os.Signal) + signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + <-c + globalBoomer.Quit() + wg.Done() + }() + + wg.Wait() +} + +var globalBoomer = boomer.NewBoomer("127.0.0.1", 5557) + +func main() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + task1 := &boomer.Task{ + Name: "foo", + Weight: 10, + Fn: foo, + } + + ratelimiter, _ := boomer.NewRampUpRateLimiter(1000, "100/1s", time.Second) + log.Println("the max rps is limited to 1000/s, with a rampup rate 100/1s.") + globalBoomer.SetRateLimiter(ratelimiter) + + globalBoomer.Run(task1) + + waitForQuit() + log.Println("shut down") +} diff --git a/legacy.go b/legacy.go index 16f1509..8bb8d5e 100644 --- a/legacy.go +++ b/legacy.go @@ -1,15 +1,49 @@ +// This file is kept to ensure backward compatibility. + package boomer import ( + "flag" "fmt" "log" + "math" "reflect" "sync" + "time" ) +var masterHost string +var masterPort int +var maxRPS int64 +var requestIncreaseRate string +var hatchType string +var runTasks string +var memoryProfile string +var memoryProfileDuration time.Duration +var cpuProfile string +var cpuProfileDuration time.Duration + var successRetiredWarning = &sync.Once{} var failureRetiredWarning = &sync.Once{} +func createRateLimiter(maxRPS int64, requestIncreaseRate string) (rateLimiter RateLimiter, err error) { + if requestIncreaseRate != "-1" { + if maxRPS > 0 { + log.Println("The max RPS that boomer may generate is limited to", maxRPS, "with a increase rate", requestIncreaseRate) + rateLimiter, err = NewRampUpRateLimiter(maxRPS, requestIncreaseRate, time.Second) + } else { + log.Println("The max RPS that boomer may generate is limited by a increase rate", requestIncreaseRate) + rateLimiter, err = NewRampUpRateLimiter(math.MaxInt64, requestIncreaseRate, time.Second) + } + } else { + if maxRPS > 0 { + log.Println("The max RPS that boomer may generate is limited to", maxRPS) + rateLimiter = NewStableRateLimiter(maxRPS, time.Second) + } + } + return rateLimiter, err +} + // According to locust, responseTime should be int64, in milliseconds. // But previous version of boomer required responseTime to be float64, so sad. func convertResponseTime(origin interface{}) int64 { @@ -28,27 +62,30 @@ func legacySuccessHandler(requestType string, name string, responseTime interfac successRetiredWarning.Do(func() { log.Println("boomer.Events.Publish(\"request_success\") is less performant and deprecated, use boomer.RecordSuccess() instead.") }) - defaultRunner.stats.requestSuccessChannel <- &requestSuccess{ - requestType: requestType, - name: name, - responseTime: convertResponseTime(responseTime), - responseLength: responseLength, - } + defaultBoomer.RecordSuccess(requestType, name, convertResponseTime(responseTime), responseLength) } func legacyFailureHandler(requestType string, name string, responseTime interface{}, exception string) { failureRetiredWarning.Do(func() { log.Println("boomer.Events.Publish(\"request_failure\") is less performant and deprecated, use boomer.RecordFailure() instead.") }) - defaultRunner.stats.requestFailureChannel <- &requestFailure{ - requestType: requestType, - name: name, - responseTime: convertResponseTime(responseTime), - error: exception, - } + defaultBoomer.RecordFailure(requestType, name, convertResponseTime(responseTime), exception) } func initLegacyEventHandlers() { Events.Subscribe("request_success", legacySuccessHandler) Events.Subscribe("request_failure", legacyFailureHandler) } + +func init() { + flag.Int64Var(&maxRPS, "max-rps", 0, "Max RPS that boomer can generate, disabled by default.") + flag.StringVar(&requestIncreaseRate, "request-increase-rate", "-1", "Request increase rate, disabled by default.") + flag.StringVar(&hatchType, "hatch-type", "asap", "How to create goroutines according to hatch rate, 'asap' will do it as soon as possible while 'smooth' means a constant pace.") + flag.StringVar(&runTasks, "run-tasks", "", "Run tasks without connecting to the master, multiply tasks is separated by comma. Usually, it's for debug purpose.") + flag.StringVar(&masterHost, "master-host", "127.0.0.1", "Host or IP address of locust master for distributed load testing.") + flag.IntVar(&masterPort, "master-port", 5557, "The port to connect to that is used by the locust master for distributed load testing.") + flag.StringVar(&memoryProfile, "mem-profile", "", "Enable memory profiling.") + flag.DurationVar(&memoryProfileDuration, "mem-profile-duration", 30*time.Second, "Memory profile duration.") + flag.StringVar(&cpuProfile, "cpu-profile", "", "Enable CPU profiling.") + flag.DurationVar(&cpuProfileDuration, "cpu-profile-duration", 30*time.Second, "CPU profile duration.") +} diff --git a/legacy_test.go b/legacy_test.go index 3b0b046..a299c41 100644 --- a/legacy_test.go +++ b/legacy_test.go @@ -24,13 +24,16 @@ func TestConvertResponseTime(t *testing.T) { func TestInitEvents(t *testing.T) { initLegacyEventHandlers() + defer Events.Unsubscribe("request_success", legacySuccessHandler) + defer Events.Unsubscribe("request_failure", legacyFailureHandler) - defaultRunner = newRunner(nil, nil, "asap") + defaultBoomer = NewBoomer("127.0.0.1", 5557) + defaultBoomer.runner = newRunner(nil, nil, "asap") Events.Publish("request_success", "http", "foo", int64(1), int64(10)) Events.Publish("request_failure", "udp", "bar", int64(2), "udp error") - requestSuccessMsg := <-defaultRunner.stats.requestSuccessChannel + requestSuccessMsg := <-defaultBoomer.runner.stats.requestSuccessChannel if requestSuccessMsg.requestType != "http" { t.Error("Expected: http, got:", requestSuccessMsg.requestType) } @@ -38,7 +41,7 @@ func TestInitEvents(t *testing.T) { t.Error("Expected: 1, got:", requestSuccessMsg.responseTime) } - requestFailureMsg := <-defaultRunner.stats.requestFailureChannel + requestFailureMsg := <-defaultBoomer.runner.stats.requestFailureChannel if requestFailureMsg.requestType != "udp" { t.Error("Expected: udp, got:", requestFailureMsg.requestType) } @@ -48,9 +51,4 @@ func TestInitEvents(t *testing.T) { if requestFailureMsg.error != "udp error" { t.Error("Expected: udp error, got:", requestFailureMsg.error) } - - Events.Unsubscribe("request_success", legacySuccessHandler) - Events.Unsubscribe("request_failure", legacyFailureHandler) - - defaultRunner = nil } diff --git a/ratelimiter.go b/ratelimiter.go index 607c36d..5cf98b9 100644 --- a/ratelimiter.go +++ b/ratelimiter.go @@ -9,34 +9,53 @@ import ( "time" ) -// runner uses a rate limiter to put limits on task executions. -type rateLimiter interface { - start() - acquire() bool - stop() +// RateLimiter is used to put limits on task executions. +type RateLimiter interface { + // Start is used to enable the rate limiter. + // It can be implemented as a noop if not needed. + Start() + + // Acquire() is called before executing a task.Fn function. + // If Acquire() returns true, the task.Fn function will be executed. + // If Acquire() returns false, the task.Fn function won't be executed this time, but Acquire() will be called very soon. + // It works like: + // for { + // blocked := rateLimiter.Acquire() + // if !blocked { + // task.Fn() + // } + // } + // Acquire() should block the caller until execution is allowed. + Acquire() bool + + // Stop is used to disable the rate limiter. + // It can be implemented as a noop if not needed. + Stop() } -// stableRateLimiter uses the token bucket algorithm. +// A StableRateLimiter uses the token bucket algorithm. // the bucket is refilled according to the refill period, no burst is allowed. -type stableRateLimiter struct { +type StableRateLimiter struct { threshold int64 currentThreshold int64 - refillPeroid time.Duration + refillPeriod time.Duration broadcastChannel chan bool quitChannel chan bool } -func newStableRateLimiter(threshold int64, refillPeroid time.Duration) (rateLimiter *stableRateLimiter) { - rateLimiter = &stableRateLimiter{ +// NewStableRateLimiter returns a StableRateLimiter. +func NewStableRateLimiter(threshold int64, refillPeriod time.Duration) (rateLimiter *StableRateLimiter) { + rateLimiter = &StableRateLimiter{ threshold: threshold, currentThreshold: threshold, - refillPeroid: refillPeroid, + refillPeriod: refillPeriod, broadcastChannel: make(chan bool), } return rateLimiter } -func (limiter *stableRateLimiter) start() { +// Start to refill the bucket periodically. +func (limiter *StableRateLimiter) Start() { limiter.quitChannel = make(chan bool) quitChannel := limiter.quitChannel go func() { @@ -46,7 +65,7 @@ func (limiter *stableRateLimiter) start() { return default: atomic.StoreInt64(&limiter.currentThreshold, limiter.threshold) - time.Sleep(limiter.refillPeroid) + time.Sleep(limiter.refillPeriod) close(limiter.broadcastChannel) limiter.broadcastChannel = make(chan bool) } @@ -54,7 +73,8 @@ func (limiter *stableRateLimiter) start() { }() } -func (limiter *stableRateLimiter) acquire() (blocked bool) { +// Acquire a token from the bucket, returns true if the bucket is exhausted. +func (limiter *StableRateLimiter) Acquire() (blocked bool) { permit := atomic.AddInt64(&limiter.currentThreshold, -1) if permit < 0 { blocked = true @@ -66,21 +86,22 @@ func (limiter *stableRateLimiter) acquire() (blocked bool) { return blocked } -func (limiter *stableRateLimiter) stop() { +// Stop the rate limiter. +func (limiter *StableRateLimiter) Stop() { close(limiter.quitChannel) } // ErrParsingRampUpRate is the error returned if the format of rampUpRate is invalid. var ErrParsingRampUpRate = errors.New("ratelimiter: invalid format of rampUpRate, try \"1\" or \"1/1s\"") -// rampUpRateLimiter uses the token bucket algorithm. +// A RampUpRateLimiter uses the token bucket algorithm. // the threshold is updated according to the warm up rate. // the bucket is refilled according to the refill period, no burst is allowed. -type rampUpRateLimiter struct { +type RampUpRateLimiter struct { maxThreshold int64 nextThreshold int64 currentThreshold int64 - refillPeroid time.Duration + refillPeriod time.Duration rampUpRate string rampUpStep int64 rampUpPeroid time.Duration @@ -89,13 +110,15 @@ type rampUpRateLimiter struct { quitChannel chan bool } -func newRampUpRateLimiter(maxThreshold int64, rampUpRate string, refillPeroid time.Duration) (rateLimiter *rampUpRateLimiter, err error) { - rateLimiter = &rampUpRateLimiter{ +// NewRampUpRateLimiter returns a RampUpRateLimiter. +// Valid formats of rampUpRate are "1", "1/1s". +func NewRampUpRateLimiter(maxThreshold int64, rampUpRate string, refillPeriod time.Duration) (rateLimiter *RampUpRateLimiter, err error) { + rateLimiter = &RampUpRateLimiter{ maxThreshold: maxThreshold, nextThreshold: 0, currentThreshold: 0, rampUpRate: rampUpRate, - refillPeroid: refillPeroid, + refillPeriod: refillPeriod, broadcastChannel: make(chan bool), } rateLimiter.rampUpStep, rateLimiter.rampUpPeroid, err = rateLimiter.parseRampUpRate(rateLimiter.rampUpRate) @@ -105,7 +128,7 @@ func newRampUpRateLimiter(maxThreshold int64, rampUpRate string, refillPeroid ti return rateLimiter, nil } -func (limiter *rampUpRateLimiter) parseRampUpRate(rampUpRate string) (rampUpStep int64, rampUpPeroid time.Duration, err error) { +func (limiter *RampUpRateLimiter) parseRampUpRate(rampUpRate string) (rampUpStep int64, rampUpPeroid time.Duration, err error) { if strings.Contains(rampUpRate, "/") { tmp := strings.Split(rampUpRate, "/") if len(tmp) != 2 { @@ -130,7 +153,8 @@ func (limiter *rampUpRateLimiter) parseRampUpRate(rampUpRate string) (rampUpStep return rampUpStep, rampUpPeroid, nil } -func (limiter *rampUpRateLimiter) start() { +// Start to refill the bucket periodically. +func (limiter *RampUpRateLimiter) Start() { limiter.quitChannel = make(chan bool) quitChannel := limiter.quitChannel // bucket updater @@ -141,7 +165,7 @@ func (limiter *rampUpRateLimiter) start() { return default: atomic.StoreInt64(&limiter.currentThreshold, limiter.nextThreshold) - time.Sleep(limiter.refillPeroid) + time.Sleep(limiter.refillPeriod) close(limiter.broadcastChannel) limiter.broadcastChannel = make(chan bool) } @@ -169,7 +193,8 @@ func (limiter *rampUpRateLimiter) start() { }() } -func (limiter *rampUpRateLimiter) acquire() (blocked bool) { +// Acquire a token from the bucket, returns true if the bucket is exhausted. +func (limiter *RampUpRateLimiter) Acquire() (blocked bool) { permit := atomic.AddInt64(&limiter.currentThreshold, -1) if permit < 0 { blocked = true @@ -181,7 +206,8 @@ func (limiter *rampUpRateLimiter) acquire() (blocked bool) { return blocked } -func (limiter *rampUpRateLimiter) stop() { +// Stop the rate limiter. +func (limiter *RampUpRateLimiter) Stop() { limiter.nextThreshold = 0 close(limiter.quitChannel) } diff --git a/ratelimiter_test.go b/ratelimiter_test.go index e95315d..6751d2c 100644 --- a/ratelimiter_test.go +++ b/ratelimiter_test.go @@ -6,31 +6,31 @@ import ( ) func TestStableRateLimiter(t *testing.T) { - rateLimiter := newStableRateLimiter(1, 10*time.Millisecond) - rateLimiter.start() - blocked := rateLimiter.acquire() + rateLimiter := NewStableRateLimiter(1, 10*time.Millisecond) + rateLimiter.Start() + blocked := rateLimiter.Acquire() if blocked { t.Error("Unexpected blocked by rate limiter") } - blocked = rateLimiter.acquire() + blocked = rateLimiter.Acquire() if !blocked { t.Error("Should be blocked") } - rateLimiter.stop() + rateLimiter.Stop() } func TestRampUpRateLimiter(t *testing.T) { - rateLimiter, _ := newRampUpRateLimiter(100, "10/200ms", 100*time.Millisecond) - rateLimiter.start() + rateLimiter, _ := NewRampUpRateLimiter(100, "10/200ms", 100*time.Millisecond) + rateLimiter.Start() time.Sleep(110 * time.Millisecond) for i := 0; i < 10; i++ { - blocked := rateLimiter.acquire() + blocked := rateLimiter.Acquire() if blocked { t.Error("Unexpected blocked by rate limiter") } } - blocked := rateLimiter.acquire() + blocked := rateLimiter.Acquire() if !blocked { t.Error("Should be blocked") } @@ -39,39 +39,39 @@ func TestRampUpRateLimiter(t *testing.T) { // now, the threshold is 20 for i := 0; i < 20; i++ { - blocked := rateLimiter.acquire() + blocked := rateLimiter.Acquire() if blocked { t.Error("Unexpected blocked by rate limiter") } } - blocked = rateLimiter.acquire() + blocked = rateLimiter.Acquire() if !blocked { t.Error("Should be blocked") } - rateLimiter.stop() + rateLimiter.Stop() } func TestParseRampUpRate(t *testing.T) { - rateLimiter := &rampUpRateLimiter{} - rampUpStep, rampUpPeroid, _ := rateLimiter.parseRampUpRate("100") + rateLimiter := &RampUpRateLimiter{} + rampUpStep, rampUpPeriod, _ := rateLimiter.parseRampUpRate("100") if rampUpStep != 100 { t.Error("Wrong rampUpStep, expected: 100, was:", rampUpStep) } - if rampUpPeroid != time.Second { - t.Error("Wrong rampUpPeroid, expected: 1s, was:", rampUpPeroid) + if rampUpPeriod != time.Second { + t.Error("Wrong rampUpPeriod, expected: 1s, was:", rampUpPeriod) } - rampUpStep, rampUpPeroid, _ = rateLimiter.parseRampUpRate("200/10s") + rampUpStep, rampUpPeriod, _ = rateLimiter.parseRampUpRate("200/10s") if rampUpStep != 200 { t.Error("Wrong rampUpStep, expected: 200, was:", rampUpStep) } - if rampUpPeroid != 10*time.Second { - t.Error("Wrong rampUpPeroid, expected: 10s, was:", rampUpPeroid) + if rampUpPeriod != 10*time.Second { + t.Error("Wrong rampUpPeriod, expected: 10s, was:", rampUpPeriod) } } func TestParseInvalidRampUpRate(t *testing.T) { - rateLimiter := &rampUpRateLimiter{} + rateLimiter := &RampUpRateLimiter{} _, _, err := rateLimiter.parseRampUpRate("A/1m") if err == nil || err != ErrParsingRampUpRate { diff --git a/runner.go b/runner.go index d8e2f1a..8f21b71 100644 --- a/runner.go +++ b/runner.go @@ -41,14 +41,12 @@ type runner struct { client client nodeID string hatchType string - rateLimiter rateLimiter + rateLimiter RateLimiter rateLimitEnabled bool stats *requestStats - // cache of current time in second - now int64 } -func newRunner(tasks []*Task, rateLimiter rateLimiter, hatchType string) (r *runner) { +func newRunner(tasks []*Task, rateLimiter RateLimiter, hatchType string) (r *runner) { r = &runner{ tasks: tasks, hatchType: hatchType, @@ -61,10 +59,6 @@ func newRunner(tasks []*Task, rateLimiter rateLimiter, hatchType string) (r *run r.rateLimiter = rateLimiter } - if hatchType != "asap" && hatchType != "smooth" { - log.Fatalf("Wrong hatch-type, expected asap or smooth, was %s\n", hatchType) - } - r.stats = newRequestStats() return r @@ -122,7 +116,7 @@ func (r *runner) spawnGoRoutines(spawnCount int, quit chan bool) { return default: if r.rateLimitEnabled { - blocked := r.rateLimiter.acquire() + blocked := r.rateLimiter.Acquire() if !blocked { r.safeRun(fn) } @@ -170,7 +164,7 @@ func (r *runner) stop() { // those goroutines will exit when r.safeRun returns close(r.stopChannel) if r.rateLimitEnabled { - r.rateLimiter.stop() + r.rateLimiter.Stop() } } @@ -202,7 +196,7 @@ func (r *runner) onHatchMessage(msg *message) { Events.Publish("boomer:hatch", workers, hatchRate) if r.rateLimitEnabled { - r.rateLimiter.start() + r.rateLimiter.Start() } r.startHatching(workers, hatchRate) } @@ -210,18 +204,14 @@ func (r *runner) onHatchMessage(msg *message) { // Runner acts as a state machine, and runs in one goroutine without any lock. func (r *runner) onMessage(msg *message) { - if msg.Type == "quit" { - log.Println("Got quit message from master, shutting down...") - r.state = stateQuitting - Events.Publish("boomer:quit") - os.Exit(0) - } - switch r.state { case stateInit: - if msg.Type == "hatch" { + switch msg.Type { + case "hatch": r.state = stateHatching r.onHatchMessage(msg) + case "quit": + Events.Publish("boomer:quit") } case stateHatching: fallthrough @@ -237,11 +227,20 @@ func (r *runner) onMessage(msg *message) { log.Println("Recv stop message from master, all the goroutines are stopped") r.client.sendChannel() <- newMessage("client_stopped", nil, r.nodeID) r.client.sendChannel() <- newMessage("client_ready", nil, r.nodeID) + case "quit": + r.stop() + log.Println("Recv quit message from master, all the goroutines are stopped") + Events.Publish("boomer:quit") + r.state = stateInit } case stateStopped: - if msg.Type == "hatch" { + switch msg.Type { + case "hatch": r.state = stateHatching r.onHatchMessage(msg) + case "quit": + Events.Publish("boomer:quit") + r.state = stateInit } } } @@ -288,17 +287,5 @@ func (r *runner) getReady() { } }() - go func() { - var ticker = time.NewTicker(time.Second) - for { - select { - case <-ticker.C: - r.now = time.Now().Unix() - case <-r.shutdownSignal: - return - } - } - }() - Events.Subscribe("boomer:quit", r.onQuiting) } diff --git a/runner_test.go b/runner_test.go index 0eea0a1..cbe3178 100644 --- a/runner_test.go +++ b/runner_test.go @@ -7,11 +7,10 @@ import ( ) func TestSafeRun(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") - defaultRunner.safeRun(func() { + runner := newRunner(nil, nil, "asap") + runner.safeRun(func() { panic("Runner will catch this panic") }) - defaultRunner = nil } func TestSpawnGoRoutines(t *testing.T) { @@ -30,7 +29,7 @@ func TestSpawnGoRoutines(t *testing.T) { Name: "TaskB", } tasks := []*Task{taskA, taskB} - rateLimiter := newStableRateLimiter(100, time.Second) + rateLimiter := NewStableRateLimiter(100, time.Second) runner := newRunner(tasks, rateLimiter, "asap") defer runner.close() @@ -327,7 +326,7 @@ func TestGetReady(t *testing.T) { defer server.close() server.start() - rateLimiter := newStableRateLimiter(100, time.Second) + rateLimiter := NewStableRateLimiter(100, time.Second) r := newRunner(nil, rateLimiter, "asap") r.masterHost = masterHost r.masterPort = masterPort diff --git a/stats.go b/stats.go index c18ab59..7fe0580 100644 --- a/stats.go +++ b/stats.go @@ -197,9 +197,7 @@ func (s *statsEntry) log(responseTime int64, contentLength int64) { } func (s *statsEntry) logTimeOfRequest() { - // 'now' is updated by another goroutine - // make a copy to avoid race condition - key := defaultRunner.now + key := time.Now().Unix() _, ok := s.numReqsPerSec[key] if !ok { s.numReqsPerSec[key] = 1 diff --git a/stats_test.go b/stats_test.go index 47b9c17..da851be 100644 --- a/stats_test.go +++ b/stats_test.go @@ -6,7 +6,6 @@ import ( ) func TestLogRequest(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() newStats.logRequest("http", "success", 2, 30) newStats.logRequest("http", "success", 3, 40) @@ -46,21 +45,16 @@ func TestLogRequest(t *testing.T) { if newStats.total.totalContentLength != 130 { t.Error("newStats.total.totalContentLength is wrong, expected: 130, got:", newStats.total.totalContentLength) } - - defaultRunner = nil } func BenchmarkLogRequest(b *testing.B) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() for i := 0; i < b.N; i++ { newStats.logRequest("http", "success", 2, 30) } - defaultRunner = nil } func TestRoundedResponseTime(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() newStats.logRequest("http", "success", 147, 1) newStats.logRequest("http", "success", 3432, 1) @@ -83,12 +77,9 @@ func TestRoundedResponseTime(t *testing.T) { if val, ok := responseTimes[59000]; !ok || val != 1 { t.Error("Rounded response time should be", 59000) } - - defaultRunner = nil } func TestLogError(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() newStats.logError("http", "failure", "500 error") newStats.logError("http", "failure", "400 error") @@ -121,22 +112,18 @@ func TestLogError(t *testing.T) { t.Error("Error occurences is wrong, expected: 2, got:", err400.occurences) } - defaultRunner = nil } func BenchmarkLogError(b *testing.B) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() for i := 0; i < b.N; i++ { // LogError use md5 to calculate hash keys, it may slow down the only goroutine, // which consumes both requestSuccessChannel and requestFailureChannel. newStats.logError("http", "failure", "500 error") } - defaultRunner = nil } func TestClearAll(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() newStats.logRequest("http", "success", 1, 20) newStats.clearAll() @@ -144,11 +131,9 @@ func TestClearAll(t *testing.T) { if newStats.total.numRequests != 0 { t.Error("After clearAll(), newStats.total.numRequests is wrong, expected: 0, got:", newStats.total.numRequests) } - defaultRunner = nil } func TestClearAllByChannel(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() newStats.start() defer newStats.close() @@ -158,11 +143,9 @@ func TestClearAllByChannel(t *testing.T) { if newStats.total.numRequests != 0 { t.Error("After clearAll(), newStats.total.numRequests is wrong, expected: 0, got:", newStats.total.numRequests) } - defaultRunner = nil } func TestSerializeStats(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() newStats.logRequest("http", "success", 1, 20) @@ -185,12 +168,9 @@ func TestSerializeStats(t *testing.T) { if first["num_failures"].(int64) != int64(0) { t.Error("The num_failures is wrong, expected:", 0, "got:", first["num_failures"].(int64)) } - - defaultRunner = nil } func TestSerializeErrors(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() newStats.logError("http", "failure", "500 error") newStats.logError("http", "failure", "400 error") @@ -214,11 +194,9 @@ func TestSerializeErrors(t *testing.T) { } } } - defaultRunner = nil } func TestCollectReportData(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() newStats.logRequest("http", "success", 2, 30) newStats.logError("http", "failure", "500 error") @@ -233,11 +211,9 @@ func TestCollectReportData(t *testing.T) { if _, ok := result["errors"]; !ok { t.Error("Key stats not found") } - defaultRunner = nil } func TestStatsStart(t *testing.T) { - defaultRunner = newRunner(nil, nil, "asap") newStats := newRequestStats() newStats.start() defer newStats.close() @@ -266,6 +242,4 @@ func TestStatsStart(t *testing.T) { } } end: - - defaultRunner = nil } diff --git a/utils.go b/utils.go index 7c0a8bf..ffe567b 100644 --- a/utils.go +++ b/utils.go @@ -4,8 +4,10 @@ import ( "crypto/md5" "fmt" "io" + "log" "math" "os" + "runtime/pprof" "strings" "time" @@ -47,3 +49,44 @@ func getNodeID() (nodeID string) { func Now() int64 { return time.Now().UnixNano() / int64(time.Millisecond) } + +// StartMemoryProfile starts memory profiling and save the results in file. +func StartMemoryProfile(file string, duration time.Duration) { + f, err := os.Create(file) + if err != nil { + log.Fatal(err) + } + + log.Println("Start memory profiling for", duration) + time.AfterFunc(duration, func() { + err = pprof.WriteHeapProfile(f) + if err != nil { + log.Println(err) + return + } + f.Close() + log.Println("Stop memory profiling after", duration) + }) +} + +// StartCPUProfile starts cpu profiling and save the results in file. +func StartCPUProfile(file string, duration time.Duration) { + f, err := os.Create(file) + if err != nil { + log.Fatal(err) + } + + log.Println("Start cpu profiling for", duration) + err = pprof.StartCPUProfile(f) + if err != nil { + log.Println(err) + f.Close() + return + } + + time.AfterFunc(duration, func() { + pprof.StopCPUProfile() + f.Close() + log.Println("Stop CPU profiling after", duration) + }) +} diff --git a/utils_test.go b/utils_test.go index 131081e..17fb998 100644 --- a/utils_test.go +++ b/utils_test.go @@ -4,6 +4,7 @@ import ( "os" "regexp" "testing" + "time" ) func TestRound(t *testing.T) { @@ -59,3 +60,29 @@ func TestNow(t *testing.T) { t.Error("Invalid format of timestamp in milliseconds") } } + +func TestStartMemoryProfile(t *testing.T) { + if _, err := os.Stat("mem.pprof"); os.IsExist(err) { + os.Remove("mem.pprof") + } + StartMemoryProfile("mem.pprof", 2*time.Second) + time.Sleep(2100 * time.Millisecond) + if _, err := os.Stat("mem.pprof"); os.IsNotExist(err) { + t.Error("File mem.pprof is not generated") + } else { + os.Remove("mem.pprof") + } +} + +func TestStartCPUProfile(t *testing.T) { + if _, err := os.Stat("cpu.pprof"); os.IsExist(err) { + os.Remove("cpu.pprof") + } + StartCPUProfile("cpu.pprof", 2*time.Second) + time.Sleep(2100 * time.Millisecond) + if _, err := os.Stat("cpu.pprof"); os.IsNotExist(err) { + t.Error("File cpu.pprof is not generated") + } else { + os.Remove("cpu.pprof") + } +}