From 24a922307e07775f1082fd26c12289c865e906aa Mon Sep 17 00:00:00 2001 From: Marcel Edmund Franke Date: Sun, 9 May 2021 17:51:30 +0200 Subject: [PATCH] fix golint issues --- bus/bus_test.go | 4 +- cmd/xcode/main.go | 12 +- concurrent/runner_test.go | 6 +- debugutil/http_log_round_tripper_test.go | 1 + debugutil/prettysprint_test.go | 5 +- event/hook.go | 10 +- event/hook_test.go | 2 +- internal/ast/ast.go | 12 +- lease/doc_test.go | 2 +- loop/loop.go | 10 +- loop/loop_test.go | 12 +- multierror/doc_test.go | 8 +- promise/promise_test.go | 2 +- retry/retry.go | 12 +- retry/retry_test.go | 8 +- retry/roundtripper.go | 8 ++ retry/roundtripper_test.go | 165 +++++++++-------------- schedule/schedule.go | 4 +- singleton/singleton.go | 3 + worker/worker_test.go | 10 +- xhttp/xhttp.go | 19 ++- xhttp/xhttp_test.go | 23 +++- 22 files changed, 197 insertions(+), 141 deletions(-) diff --git a/bus/bus_test.go b/bus/bus_test.go index 2fb2c52..15137e5 100644 --- a/bus/bus_test.go +++ b/bus/bus_test.go @@ -13,11 +13,13 @@ type msg struct { body string } +var handlerError error = errors.New("handler error") + func TestHandlerReturnsError(t *testing.T) { b := bus.New() err := b.AddHandler(func(m *msg) error { - return errors.New("handler error") + return handlerError }) if err != nil { t.Fatal(err) diff --git a/cmd/xcode/main.go b/cmd/xcode/main.go index dc0a30f..4fc2b0e 100644 --- a/cmd/xcode/main.go +++ b/cmd/xcode/main.go @@ -15,6 +15,7 @@ func main() { log.SetFlags(0) fs := flag.NewFlagSet("xcode", flag.ExitOnError) + var ( in = fs.String("in", "", "input file") out = fs.String("out", "", "output file") @@ -22,8 +23,13 @@ func main() { typ = fs.String("type", "", "type") mode = fs.String("mode", "", "activate mode") ) + fs.Usage = usageFor(fs, "xcode [flags]") - fs.Parse(os.Args[1:]) + + err := fs.Parse(os.Args[1:]) + if err != nil { + log.Fatal(err) + } if *in == "" { log.Fatal("input file is missing") @@ -61,11 +67,14 @@ func main() { func usageFor(fs *flag.FlagSet, short string) func() { return func() { + fmt.Fprintf(os.Stdout, "USAGE\n") fmt.Fprintf(os.Stdout, " %s\n", short) fmt.Fprintf(os.Stdout, "\n") fmt.Fprintf(os.Stdout, "FLAGS\n") + tw := tabwriter.NewWriter(os.Stdout, 0, 2, 2, ' ', 0) + fs.VisitAll(func(f *flag.Flag) { if f.Name == "debug" { return @@ -76,6 +85,7 @@ func usageFor(fs *flag.FlagSet, short string) func() { } fmt.Fprintf(tw, " -%s %s\t%s\n", f.Name, f.DefValue, f.Usage) }) + tw.Flush() } } diff --git a/concurrent/runner_test.go b/concurrent/runner_test.go index 17f04dd..8b02ded 100644 --- a/concurrent/runner_test.go +++ b/concurrent/runner_test.go @@ -8,6 +8,8 @@ import ( "github.com/donutloop/toolkit/concurrent" ) +var StubErr error = errors.New("stub error") + func TestRun(t *testing.T) { counter := int32(0) errs := concurrent.Run( @@ -38,7 +40,7 @@ func TestRunFail(t *testing.T) { counter := int32(0) errs := concurrent.Run( func() error { - return errors.New("stub error") + return StubErr }, func() error { panic("check isolation of goroutine") @@ -64,7 +66,7 @@ func BenchmarkRun(b *testing.B) { for n := 0; n < b.N; n++ { concurrent.Run( func() error { - return errors.New("stub error") + return StubErr }, func() error { panic("check isolation of goroutine") diff --git a/debugutil/http_log_round_tripper_test.go b/debugutil/http_log_round_tripper_test.go index 64c0f10..2775018 100644 --- a/debugutil/http_log_round_tripper_test.go +++ b/debugutil/http_log_round_tripper_test.go @@ -16,6 +16,7 @@ type logger struct{} func (l logger) Errorf(format string, v ...interface{}) { log.Println(fmt.Sprintf(format, v...)) } + func (l logger) Infof(format string, v ...interface{}) { log.Println(fmt.Sprintf(format, v...)) } diff --git a/debugutil/prettysprint_test.go b/debugutil/prettysprint_test.go index 32df21c..ecee8fb 100644 --- a/debugutil/prettysprint_test.go +++ b/debugutil/prettysprint_test.go @@ -5,12 +5,11 @@ package debugutil_test import ( - "testing" - "github.com/donutloop/toolkit/debugutil" + "testing" ) -func Test(t *testing.T) { +func TestDebugger(t *testing.T){ strings := "dummy" diff --git a/event/hook.go b/event/hook.go index 62e6c43..a41e14c 100644 --- a/event/hook.go +++ b/event/hook.go @@ -2,6 +2,7 @@ package event import ( "fmt" + "runtime/debug" "sync" ) @@ -43,10 +44,17 @@ func (h *Hooks) Fire() []error { func hookWrapper(wg *sync.WaitGroup, hook func(), errc chan error) { defer func() { if v := recover(); v != nil { - errc <- fmt.Errorf("hook is panicked (%v)", v) + errc <- &RecoverError{Err: v, Stack: debug.Stack()} } wg.Done() }() hook() } + +type RecoverError struct { + Err interface{} + Stack []byte +} + +func (e *RecoverError) Error() string { return fmt.Sprintf("Do panicked: %v", e.Err) } diff --git a/event/hook_test.go b/event/hook_test.go index a8b8ad2..8f3503b 100644 --- a/event/hook_test.go +++ b/event/hook_test.go @@ -48,7 +48,7 @@ func TestHooksPanic(t *testing.T) { t.Fatalf("error count is bad (%d)", len(errs)) } - expectedMessage := "hook is panicked (check isolation of goroutine)" + expectedMessage := "Do panicked: check isolation of goroutine" if errs[0].Error() != expectedMessage { t.Fatalf(`unexpected error message (actual: "%s", expected: "%s")`, errs[0].Error(), expectedMessage) } diff --git a/internal/ast/ast.go b/internal/ast/ast.go index eacd289..66e688f 100644 --- a/internal/ast/ast.go +++ b/internal/ast/ast.go @@ -37,7 +37,7 @@ func ChangeType(typeName string, newType string, debugMode string) func(file *as for i := 0; i < len(x.Args); i++ { v, ok := x.Args[i].(*ast.Ident) if ok { - if strings.ToLower(typeName) == strings.ToLower(v.Name) { + if strings.EqualFold(typeName, v.Name) { x.Args[i] = &ast.Ident{Name: fmt.Sprintf("%s.(%s)", v.Name, newType)} } } @@ -73,8 +73,16 @@ func ModifyAst(dest []byte, fns ...func(*ast.File) *ast.File) ([]byte, error) { var buf bytes.Buffer if err := format.Node(&buf, destFset, destF); err != nil { - return nil, fmt.Errorf("couldn't format package code (%v)", err) + return nil, &BadFormattedCode{Err: err} } return buf.Bytes(), nil } + +type BadFormattedCode struct { + Err error +} + +func (e BadFormattedCode) Error() string { + return fmt.Sprintf("couldn't format package code (%v)", e.Err) +} diff --git a/lease/doc_test.go b/lease/doc_test.go index f5f7c10..5e69663 100644 --- a/lease/doc_test.go +++ b/lease/doc_test.go @@ -10,7 +10,7 @@ import ( func ExampleLeaser_Lease() { leaser := lease.NewLeaser() - leaser.Lease("cleanup-cache", time.Duration(1*time.Second), func() { + leaser.Lease("cleanup-cache", 1*time.Second, func() { fmt.Println("cleaned up cache") }) diff --git a/loop/loop.go b/loop/loop.go index b0cc3c1..3cd75a3 100644 --- a/loop/loop.go +++ b/loop/loop.go @@ -6,6 +6,7 @@ package loop import ( "fmt" + "runtime/debug" "time" ) @@ -41,7 +42,7 @@ func (l *looper) doLoop() { defer ticker.Stop() defer func() { if v := recover(); v != nil { - l.err <- fmt.Errorf("event is panicked (%v)", v) + l.err <- &RecoverError{Err: v, Stack: debug.Stack()} } }() @@ -57,3 +58,10 @@ func (l *looper) doLoop() { } } } + +type RecoverError struct { + Err interface{} + Stack []byte +} + +func (e *RecoverError) Error() string { return fmt.Sprintf("Do panicked: %v", e.Err) } diff --git a/loop/loop_test.go b/loop/loop_test.go index 138ae8b..4eb7c77 100644 --- a/loop/loop_test.go +++ b/loop/loop_test.go @@ -13,6 +13,8 @@ import ( "github.com/donutloop/toolkit/loop" ) +var StubErr error = errors.New("stub error") + func TestLoop(t *testing.T) { var counter int l := loop.NewLooper(1*time.Millisecond, func() error { @@ -28,22 +30,24 @@ func TestLoop(t *testing.T) { } } +var GoroutineError error = fmt.Errorf("check isolation of goroutine") + func TestLoopFail(t *testing.T) { l := loop.NewLooper(1*time.Millisecond, func() error { - panic(fmt.Errorf("check isolation of goroutine")) + panic(GoroutineError) }) err := <-l.Error() - if err.Error() != "event is panicked (check isolation of goroutine)" { + if err.Error() != "Do panicked: check isolation of goroutine" { t.Fatal(err) } l = loop.NewLooper(1*time.Millisecond, func() error { - return errors.New("stub error") + return StubErr }) err = <-l.Error() - if err.Error() != "stub error" { + if !errors.Is(err, StubErr) { t.Fatal(err) } } diff --git a/multierror/doc_test.go b/multierror/doc_test.go index ffa85cb..296933d 100644 --- a/multierror/doc_test.go +++ b/multierror/doc_test.go @@ -3,16 +3,14 @@ package multierror_test import ( "fmt" - "errors" - "github.com/donutloop/toolkit/multierror" ) func Example() { errs := []error{ - errors.New("error connect to db failed"), - errors.New("error marschaling json"), + connectionError, + marshalError, } fmt.Println(multierror.New(errs...)) - // Output: multiple errors: error connect to db failed; error marschaling json + // Output: multiple errors: error connect to db failed; error marshal json } diff --git a/promise/promise_test.go b/promise/promise_test.go index 21cdc7b..07d557f 100644 --- a/promise/promise_test.go +++ b/promise/promise_test.go @@ -29,7 +29,7 @@ func TestDoPanic(t *testing.T) { t.Fatal("unexpected nil error") } - if !strings.Contains(err.Error(), "Do panicked") { + if !strings.Contains(err.Error(), "Do panicked ") { t.Fatalf(`unexpected error message (actual: "%s", expected: "promise is panicked (*)")`, err.Error()) } } diff --git a/retry/retry.go b/retry/retry.go index d011ad0..a2df611 100644 --- a/retry/retry.go +++ b/retry/retry.go @@ -23,18 +23,18 @@ type Retrier interface { type RetryableDo func() (bool, error) -func NewRetrier(InitialIntervalInSeconds, maxIntervalInSeconds float64, tries uint, strategy Strategy) Retrier { +func NewRetrier(initialIntervalInSeconds, maxIntervalInSeconds float64, tries uint, strategy Strategy) Retrier { if strategy == nil { panic("strategy is missing") } - if InitialIntervalInSeconds > maxIntervalInSeconds { - panic(fmt.Sprintf("initial interval is greater than max (initial: %f, max: %f)", InitialIntervalInSeconds, maxIntervalInSeconds)) + if initialIntervalInSeconds > maxIntervalInSeconds { + panic(fmt.Sprintf("initial interval is greater than max (initial: %f, max: %f)", initialIntervalInSeconds, maxIntervalInSeconds)) } return &retrier{ - InitialIntervalInSeconds: InitialIntervalInSeconds, + InitialIntervalInSeconds: initialIntervalInSeconds, maxIntervalInSeconds: maxIntervalInSeconds, strategy: strategy, tries: tries, @@ -54,8 +54,11 @@ func (r *retrier) Retry(ctx context.Context, do RetryableDo) error { } var err error + var done bool + for i := uint(0); !done && i < r.tries; i++ { + done, err = do() if ctx.Err() != nil { @@ -74,6 +77,7 @@ func (r *retrier) Retry(ctx context.Context, do RetryableDo) error { if !done { return new(ExhaustedError) } + return nil } diff --git a/retry/retry_test.go b/retry/retry_test.go index f4b5364..6e720df 100644 --- a/retry/retry_test.go +++ b/retry/retry_test.go @@ -12,11 +12,14 @@ import ( "github.com/donutloop/toolkit/retry" ) +var StubErr error = errors.New("stub error") + func TestRetrierRetryContextDeadlineFail(t *testing.T) { r := retry.NewRetrier(0.125, 0.25, 2, new(retry.Exp)) ctx, cancel := context.WithCancel(context.Background()) cancel() + err := r.Retry(ctx, func() (bool, error) { return true, nil }) @@ -45,15 +48,14 @@ func TestRetrierRetry(t *testing.T) { func TestRetrierRetryTriggerError(t *testing.T) { r := retry.NewRetrier(0.125, 0.25, 2, new(retry.Exp)) err := r.Retry(context.Background(), func() (bool, error) { - return false, errors.New("stub error") + return false, StubErr }) if err == nil { t.Fatal("unexpected nil error") } - expectedErrorMessage := "stub error" - if err.Error() != expectedErrorMessage { + if !errors.Is(err, StubErr) { t.Fatal(err) } } diff --git a/retry/roundtripper.go b/retry/roundtripper.go index 846841a..e3e702a 100644 --- a/retry/roundtripper.go +++ b/retry/roundtripper.go @@ -13,7 +13,9 @@ type RoundTripper struct { // NewRoundTripper is constructing a new retry RoundTripper with given default values. func NewRoundTripper(next http.RoundTripper, maxInterval, initialInterval float64, tries uint, blacklistStatusCodes []int, strategy Strategy) *RoundTripper { + retrier := NewRetrier(initialInterval, maxInterval, tries, strategy) + return &RoundTripper{ retrier: retrier, next: next, @@ -25,7 +27,9 @@ func NewRoundTripper(next http.RoundTripper, maxInterval, initialInterval float6 // if rt.next.RoundTrip(req) is return an error then it will abort the process retrying a request. func (rt *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { var resp *http.Response + err := rt.retrier.Retry(context.Background(), func() (b bool, e error) { + var err error resp, err = rt.next.RoundTrip(req) if err != nil { @@ -52,13 +56,17 @@ func (rt *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { // isStatusCode iterates over list of black listed status code that it could abort the process of retrying a request func (rt *RoundTripper) isStatusCode(statusCode int) bool { + if rt.blacklistStatusCodes == nil { return false } + for _, sc := range rt.blacklistStatusCodes { + if statusCode == sc { return true } } + return false } diff --git a/retry/roundtripper_test.go b/retry/roundtripper_test.go index f0e97d0..739c46f 100644 --- a/retry/roundtripper_test.go +++ b/retry/roundtripper_test.go @@ -12,113 +12,76 @@ import ( "github.com/donutloop/toolkit/retry" ) -func TestRoundTripper_InternalServer(t *testing.T) { - - var counter int32 - testserver := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - atomic.AddInt32(&counter, 1) - t.Log("hit endpoint") - w.WriteHeader(http.StatusInternalServerError) - })) - - retryRoundTripper := retry.NewRoundTripper(http.DefaultTransport, .50, .15, 3, nil, new(retry.Exp)) - httpClient := new(http.Client) - httpClient.Transport = retryRoundTripper - - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testserver.URL, nil) - if err != nil { - t.Fatal(err) - } - - resp, err := httpClient.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusInternalServerError { - t.Errorf("response is bad, got=%v", resp.StatusCode) - } - - if counter != 3 { - t.Errorf("counter is bad, got=%v, want=%v", counter, 3) +func TestResponseCodes(t *testing.T) { + tests := []struct { + name string + responseCode int + blacklisted []int + counter uint + }{ + { + name: "StatusCode", + responseCode: http.StatusOK, + counter: 1, + }, + { + name: "StatusCode", + responseCode: http.StatusInternalServerError, + counter: 3, + }, + { + name: "blacklisted", + responseCode: http.StatusInternalServerError, + blacklisted: []int{http.StatusInternalServerError}, + counter: 1, + }, + } + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + var counter int32 + + testsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + atomic.AddInt32(&counter, 1) + t.Log("hit endpoint") + w.WriteHeader(test.responseCode) + })) + + retryRoundTripper := retry.NewRoundTripper(http.DefaultTransport, .50, .15, test.counter, test.blacklisted, new(retry.Exp)) + httpClient := new(http.Client) + httpClient.Transport = retryRoundTripper + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testsServer.URL, nil) + if err != nil { + t.Fatal(err) + } + + resp, err := httpClient.Do(req) + if err != nil { + t.Fatal(err) + } + + defer resp.Body.Close() + + if resp.StatusCode != test.responseCode { + t.Errorf("response is bad, got=%v", resp.StatusCode) + } + + if counter != int32(test.counter) { + t.Errorf("counter is bad, got=%v, want=%v", counter, int32(test.counter)) + } + }) } } -func TestRoundTripper_InternalServerBlacklisted(t *testing.T) { - - var counter int32 - testserver := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - atomic.AddInt32(&counter, 1) - t.Log("hit endpoint") - w.WriteHeader(http.StatusInternalServerError) - })) - - retryRoundTripper := retry.NewRoundTripper(http.DefaultTransport, .50, .15, 3, []int{http.StatusInternalServerError}, new(retry.Exp)) - httpClient := new(http.Client) - httpClient.Transport = retryRoundTripper - - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testserver.URL, nil) - if err != nil { - t.Fatal(err) - } - - resp, err := httpClient.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusInternalServerError { - t.Errorf("response is bad, got=%v", resp.StatusCode) - } - - if counter != 1 { - t.Errorf("counter is bad, got=%v, want=%v", counter, 1) - } -} - -func TestRoundTripper_StatusOk(t *testing.T) { - - var counter int32 - testserver := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - atomic.AddInt32(&counter, 1) - t.Log("hit endpoint") - w.WriteHeader(http.StatusOK) - })) - - retryRoundTripper := retry.NewRoundTripper(http.DefaultTransport, .50, .15, 3, nil, new(retry.Exp)) - httpClient := new(http.Client) - httpClient.Transport = retryRoundTripper - - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testserver.URL, nil) - if err != nil { - t.Fatal(err) - } - - resp, err := httpClient.Do(req) - if err != nil { - t.Fatal(err) - } - - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - t.Errorf("response is bad, got=%v", resp.StatusCode) - } - - if counter != 1 { - t.Errorf("counter is bad, got=%v, want=%v", counter, 1) - } -} - -func TestRoundTripper_JsonStatusOk(t *testing.T) { +func TestRoundTripper_Json(t *testing.T) { json := `{"hello":"world"}` var counter int32 - testserver := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + + testsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { atomic.AddInt32(&counter, 1) t.Log("hit endpoint") @@ -152,7 +115,7 @@ func TestRoundTripper_JsonStatusOk(t *testing.T) { httpClient := new(http.Client) httpClient.Transport = retryRoundTripper - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testserver.URL, bytes.NewBuffer([]byte(json))) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testsServer.URL, bytes.NewBuffer([]byte(json))) if err != nil { t.Fatal(err) } @@ -162,6 +125,8 @@ func TestRoundTripper_JsonStatusOk(t *testing.T) { t.Fatal(err) } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { t.Fatalf("response is bad, got=%v", resp.StatusCode) } diff --git a/schedule/schedule.go b/schedule/schedule.go index 6ba87d0..2b4f13e 100644 --- a/schedule/schedule.go +++ b/schedule/schedule.go @@ -49,13 +49,15 @@ func defaultPanicHandler(stack DebugStack) { log.Println(string(stack)) } +var StoppedScheduler error = errors.New("schedule: schedule to stopped scheduler") + // Schedule schedules a job that will be ran in FIFO order sequentially. func (f *Fifo) Schedule(j Job) error { f.mu.Lock() defer f.mu.Unlock() if f.cancel == nil { - return errors.New("schedule: schedule to stopped scheduler") + return StoppedScheduler } if len(f.pendings) == 0 { diff --git a/singleton/singleton.go b/singleton/singleton.go index cbaca61..83da2d9 100644 --- a/singleton/singleton.go +++ b/singleton/singleton.go @@ -43,11 +43,14 @@ func (s *singleton) Get() (interface{}, error) { s.m.Lock() defer s.m.Unlock() if s.done == 0 { + var err error + s.object, err = s.Constructor() if err != nil { return nil, err } + defer atomic.StoreUint32(&s.done, 1) } diff --git a/worker/worker_test.go b/worker/worker_test.go index eef8c25..7954478 100644 --- a/worker/worker_test.go +++ b/worker/worker_test.go @@ -9,6 +9,14 @@ import ( "github.com/donutloop/toolkit/worker" ) +type BadValueError struct { + value interface{} +} + +func (v *BadValueError) Error() string { + return fmt.Sprintf("value is not of descired type got=%v,%#v", v.value, v.value) +} + func TestWorker(t *testing.T) { contains := func(ls []string, s string) bool { @@ -24,7 +32,7 @@ func TestWorker(t *testing.T) { workerHandler := func(parameter interface{}) (interface{}, error) { v, ok := parameter.(string) if !ok { - return false, fmt.Errorf("value is not a string got=%v", parameter) + return false, &BadValueError{value: parameter} } if !contains([]string{"hello", "golang", "world"}, v) { diff --git a/xhttp/xhttp.go b/xhttp/xhttp.go index f2b3dbe..46b9eb5 100644 --- a/xhttp/xhttp.go +++ b/xhttp/xhttp.go @@ -7,24 +7,37 @@ import ( type Middleware func(m http.RoundTripper) http.RoundTripper +var ClientNilError error = errors.New("client is nil") +var MiddlewaresNilError error = errors.New("middlewares is nil") +var MiddlewareNilError error = errors.New("middleware is nil") + // Use is wrapping up a RoundTripper with a set of middleware. func Use(client *http.Client, middlewares ...Middleware) *http.Client { + if client == nil { - panic(errors.New("client is nil")) + panic(ClientNilError) } + if len(middlewares) == 0 { - panic(errors.New("middlewares is nil")) + panic(MiddlewaresNilError) } + if client.Transport == nil { client.Transport = http.DefaultTransport } + current := client.Transport + for _, middleware := range middlewares { + if middleware == nil { - panic(errors.New("middleware is nil")) + panic(MiddlewareNilError) } + current = middleware(current) } + client.Transport = current + return client } diff --git a/xhttp/xhttp_test.go b/xhttp/xhttp_test.go index 8291644..c065098 100644 --- a/xhttp/xhttp_test.go +++ b/xhttp/xhttp_test.go @@ -1,6 +1,7 @@ package xhttp_test import ( + "errors" "log" "net/http" "net/http/httptest" @@ -16,19 +17,22 @@ type TestMiddleware struct { } func (m *TestMiddleware) RoundTrip(req *http.Request) (*http.Response, error) { - m.Log("hitted middleware ", m.ID) + m.Log("hit middleware ", m.ID) + resp, err := m.roundtripper.RoundTrip(req) if err != nil { return resp, nil } + return resp, nil } -func TestInjectMiddlware(t *testing.T) { +func TestInjectMiddleware(t *testing.T) { handler := func(w http.ResponseWriter, r *http.Request) { - log.Println("hitted handler") + log.Println("hit handler") } + s := httptest.NewServer(http.HandlerFunc(handler)) defer s.Close() @@ -47,11 +51,14 @@ func TestInjectMiddlware(t *testing.T) { httpClient := new(http.Client) httpClient = xhttp.Use(httpClient, m1, m2) httpClient = xhttp.Use(httpClient, m3) + resp, err := httpClient.Get(s.URL) if err != nil { log.Fatal(err) } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { log.Fatal(err) } @@ -61,7 +68,8 @@ func TestPanicNilClient(t *testing.T) { defer func() { v := recover() err := v.(error) - if err.Error() != "client is nil" { + + if errors.Is(err, xhttp.ClientNilError) { t.Errorf("error message is bad (%v)", v) } }() @@ -73,7 +81,8 @@ func TestPanicNilMiddleware(t *testing.T) { defer func() { v := recover() err := v.(error) - if err.Error() != "middleware is nil" { + + if !errors.Is(err, xhttp.MiddlewareNilError) { t.Errorf("error message is bad (%v)", v) } }() @@ -83,9 +92,11 @@ func TestPanicNilMiddleware(t *testing.T) { func TestPanicNilMiddlewares(t *testing.T) { defer func() { + v := recover() err := v.(error) - if err.Error() != "middlewares is nil" { + + if !errors.Is(err, xhttp.MiddlewaresNilError) { t.Errorf("error message is bad (%v)", v) } }()