Skip to content

Commit

Permalink
Added Close() method
Browse files Browse the repository at this point in the history
Expanded on some tests (and filled in missing tests)
  • Loading branch information
dselans committed Oct 11, 2020
1 parent 815c801 commit ce57e7a
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 3 deletions.
28 changes: 28 additions & 0 deletions rabbit.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type IRabbit interface {
ConsumeOnce(ctx context.Context, runFunc func(msg amqp.Delivery) error) error
Publish(ctx context.Context, routingKey string, payload []byte) error
Stop() error
Close() error
}

// Rabbit struct that is instantiated via `New()`. You should not instantiate
Expand Down Expand Up @@ -192,6 +193,20 @@ func ValidateOptions(opts *Options) error {
opts.RetryReconnectSec = DefaultRetryReconnectSec
}

validModes := []Mode{Both, Producer, Consumer}

var found bool

for _, validMode := range validModes {
if validMode == opts.Mode {
found = true
}
}

if !found {
return fmt.Errorf("invalid mode '%d'", opts.Mode)
}

return nil
}

Expand Down Expand Up @@ -342,6 +357,19 @@ func (r *Rabbit) Stop() error {
return nil
}

// Close stops any active Consume and closes the amqp connection (and channels using the conn)
//
// You should re-instantiate the rabbit lib once this is called.
func (r *Rabbit) Close() error {
r.cancel()

if err := r.Conn.Close(); err != nil {
return fmt.Errorf("unable to close amqp connection: %s", err)
}

return nil
}

func (r *Rabbit) watchNotifyClose() {
// TODO: Use a looper here
for {
Expand Down
149 changes: 146 additions & 3 deletions rabbit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
package rabbit

import (
"bytes"
"context"
"io"
"os"
"time"

. "github.com/onsi/ginkgo"
Expand Down Expand Up @@ -49,7 +52,12 @@ var _ = Describe("Rabbit", func() {
Expect(r).ToNot(BeNil())
})

It("should error with bad options", func() {
It("by default, uses Both mode", func() {
Expect(opts.Mode).To(Equal(Both))
Expect(r.Options.Mode).To(Equal(Both))
})

It("should error with missing options", func() {
r, err := New(nil)

Expect(err).ToNot(BeNil())
Expand Down Expand Up @@ -130,6 +138,31 @@ var _ = Describe("Rabbit", func() {
errChan = make(chan *ConsumeError, 1)
)

When("attempting to consume messages in producer mode", func() {
It("Consume should not block and immediately return", func() {
opts.Mode = Producer
ra, err := New(opts)

Expect(err).ToNot(HaveOccurred())
Expect(ra).ToNot(BeNil())

var exit bool

go func() {
r.Consume(nil, nil, func(m amqp.Delivery) error {
return nil
})

exit = true
}()

// Give the goroutine a little to start up
time.Sleep(50 * time.Millisecond)

Expect(exit).To(BeTrue())
})
})

When("consuming messages with a context", func() {
It("run function is executed with inbound message", func() {
receivedMessages := make([]amqp.Delivery, 0)
Expand Down Expand Up @@ -186,8 +219,8 @@ var _ = Describe("Rabbit", func() {

messages := generateRandomStrings(20)

// Publish 5 messages -> cancel -> publish remainder of messages ->
// verify runfunc was hit only 5 times
// Publish 10 messages -> cancel -> publish remainder of messages ->
// verify runfunc was hit only 10 times
publishErr1 := publishMessages(ch, opts, messages[0:10])
Expect(publishErr1).ToNot(HaveOccurred())

Expand Down Expand Up @@ -287,6 +320,21 @@ var _ = Describe("Rabbit", func() {
})

Describe("ConsumeOnce", func() {
When("Mode is Producer", func() {
It("will return an error", func() {
opts.Mode = Producer
ra, err := New(opts)

Expect(err).ToNot(HaveOccurred())
Expect(ra).ToNot(BeNil())

err = ra.ConsumeOnce(nil, func(m amqp.Delivery) error { return nil })

Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("library is configured in Producer mode"))
})
})

When("passed context is nil", func() {
It("will continue to work", func() {
var receivedMessage string
Expand Down Expand Up @@ -420,6 +468,21 @@ var _ = Describe("Rabbit", func() {

Expect(receivedMessage).To(Equal(testMessage))
})

When("Mode is Consumer", func() {
It("should return an error", func() {
opts.Mode = Consumer
ra, err := New(opts)

Expect(err).ToNot(HaveOccurred())
Expect(ra).ToNot(BeNil())

err = ra.Publish(nil, "messages", []byte("test"))

Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("library is configured in Consumer mode"))
})
})
})

When("producer server channel is nil", func() {
Expand Down Expand Up @@ -481,6 +544,55 @@ var _ = Describe("Rabbit", func() {
})
})

Describe("Close", func() {
When("called after instantiating new rabbit", func() {
It("does not error", func() {
err := r.Close()
Expect(err).ToNot(HaveOccurred())
})
})

When("called before Consume", func() {
It("should cause Consume to immediately return", func() {
err := r.Close()
Expect(err).ToNot(HaveOccurred())

// This shouldn't block because internal ctx func should have been called
r.Consume(nil, nil, func(m amqp.Delivery) error {
return nil
})

Expect(true).To(BeTrue())
})
})

When("called before ConsumeOnce", func() {
It("ConsumeOnce should timeout", func() {
err := r.Close()
Expect(err).ToNot(HaveOccurred())

// This shouldn't block because internal ctx func should have been called
err = r.ConsumeOnce(nil, func(m amqp.Delivery) error {
return nil
})

Expect(err).ToNot(HaveOccurred())
})
})

When("called before Publish", func() {
It("Publish should error", func() {
err := r.Close()
Expect(err).ToNot(HaveOccurred())

err = r.Publish(nil, "messages", []byte("testing"))

Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("channel/connection is not open"))
})
})
})

Describe("validateOptions", func() {
Context("validation combinations", func() {
BeforeEach(func() {
Expand All @@ -492,6 +604,13 @@ var _ = Describe("Rabbit", func() {
Expect(err).To(HaveOccurred())
})

It("should error on invalid mode", func() {
opts.Mode = 15
err := ValidateOptions(opts)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("invalid mode"))
})

It("errors when URL is unset", func() {
opts.URL = ""

Expand Down Expand Up @@ -634,3 +753,27 @@ func receiveMessage(ch *amqp.Channel, opts *Options) ([]byte, error) {
return nil, errors.New("timed out")
}
}

func startCapture(outC chan string) (*os.File, *os.File) {
old := os.Stdout
rf, wf, err := os.Pipe()
Expect(err).ToNot(HaveOccurred())

os.Stdout = wf

// copy the output in a separate goroutine so printing can't block indefinitely
go func() {
var buf bytes.Buffer
io.Copy(&buf, rf)
outC <- buf.String()
}()

return wf, old
}

func endCapture(wf, oldStdout *os.File, outC chan string) string {
wf.Close()
os.Stdout = oldStdout
out := <-outC
return out
}

0 comments on commit ce57e7a

Please sign in to comment.