Skip to content

Commit

Permalink
Allow default namespace on client level (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
dim authored Sep 2, 2022
1 parent 7673fd0 commit 6e5f953
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 23 deletions.
50 changes: 30 additions & 20 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,20 @@ type Client struct {
stmt struct {
push, pushWithID, get, claim, shift, list, update, done *sql.Stmt
}
opt *scopeOptions
ownDB bool
}

// Connect connects to a PG instance using a URL.
// Example:
// postgres://user:secret@test.host:5432/mydb?sslmode=verify-ca
func Connect(ctx context.Context, url string) (*Client, error) {
func Connect(ctx context.Context, url string, opts ...ScopeOption) (*Client, error) {
db, err := sql.Open("postgres", url)
if err != nil {
return nil, err
}

client, err := Wrap(ctx, db)
client, err := Wrap(ctx, db, opts...)
if err != nil {
_ = db.Close()
return nil, err
Expand All @@ -40,12 +41,18 @@ func Connect(ctx context.Context, url string) (*Client, error) {

// Wrap wraps an existing database/sql.DB instance. Please note that calling
// Close() will not close the underlying connection.
func Wrap(ctx context.Context, db *sql.DB) (*Client, error) {
func Wrap(ctx context.Context, db *sql.DB, opts ...ScopeOption) (*Client, error) {
opt := &scopeOptions{}
opt.set(opts...)
if err := opt.validate(); err != nil {
return nil, err
}

if err := validateConn(ctx, db); err != nil {
return nil, err
}

c := &Client{db: db}
c := &Client{db: db, opt: opt}
if err := c.prepareStmt(ctx); err != nil {
_ = c.Close()
return nil, err
Expand All @@ -56,9 +63,9 @@ func Wrap(ctx context.Context, db *sql.DB) (*Client, error) {
// Truncate truncates the queue and deletes all tasks. Intended for testing,
// please use with care.
func (c *Client) Truncate(ctx context.Context, opts ...ScopeOption) error {
opt := new(scopeOptions)
opt.Set(opts...)
if err := opt.Namespace.validate(); err != nil {
opt := &scopeOptions{Namespace: c.opt.Namespace}
opt.set(opts...)
if err := opt.validate(); err != nil {
return err
}

Expand All @@ -70,9 +77,9 @@ func (c *Client) Truncate(ctx context.Context, opts ...ScopeOption) error {
func (c *Client) Len(ctx context.Context, opts ...ScopeOption) (int64, error) {
var cnt int64

opt := new(scopeOptions)
opt.Set(opts...)
if err := opt.Namespace.validate(); err != nil {
opt := &scopeOptions{Namespace: c.opt.Namespace}
opt.set(opts...)
if err := opt.validate(); err != nil {
return cnt, err
}

Expand All @@ -89,9 +96,9 @@ func (c *Client) Len(ctx context.Context, opts ...ScopeOption) (int64, error) {
func (c *Client) MinCreatedAt(ctx context.Context, opts ...ScopeOption) (time.Time, error) {
var ts pq.NullTime

opt := new(scopeOptions)
opt.Set(opts...)
if err := opt.Namespace.validate(); err != nil {
opt := &scopeOptions{Namespace: c.opt.Namespace}
opt.set(opts...)
if err := opt.validate(); err != nil {
return ts.Time, err
}

Expand All @@ -111,6 +118,9 @@ func (c *Client) Push(ctx context.Context, task *Task) error {
return err
}

if task.Namespace == "" && c.opt.Namespace != "" {
task.Namespace = string(c.opt.Namespace)
}
if len(task.Payload) == 0 {
task.Payload = json.RawMessage{'{', '}'}
}
Expand Down Expand Up @@ -170,9 +180,9 @@ func (c *Client) Claim(ctx context.Context, id uuid.UUID) (*Claim, error) {
// Shift locks and returns the task with the highest priority. It may return
// ErrNoTask.
func (c *Client) Shift(ctx context.Context, opts ...ScopeOption) (*Claim, error) {
opt := new(scopeOptions)
opt.Set(opts...)
if err := opt.Namespace.validate(); err != nil {
opt := &scopeOptions{Namespace: c.opt.Namespace}
opt.set(opts...)
if err := opt.validate(); err != nil {
return nil, err
}

Expand All @@ -198,12 +208,12 @@ func (c *Client) Shift(ctx context.Context, opts ...ScopeOption) (*Claim, error)

// List lists all tasks in the queue.
func (c *Client) List(ctx context.Context, opts ...ListOption) ([]*TaskDetails, error) {
opt := new(listOptions)
opt.Set(opts...)
if err := opt.Namespace.validate(); err != nil {
opt := &listOptions{Namespace: c.opt.Namespace}
opt.set(opts...)
if err := opt.validate(); err != nil {
return nil, err
}
limit := opt.GetLimit()
limit := opt.getLimit()

rows, err := c.stmt.list.QueryContext(ctx, opt.Namespace, limit, opt.Offset)
if err != nil {
Expand Down
14 changes: 11 additions & 3 deletions pgpq.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,23 @@ type listOptions struct {
Namespace namespace
}

func (o *listOptions) GetLimit() int64 {
func (o *listOptions) getLimit() int64 {
if o.Limit == 0 {
return 100
}
return o.Limit
}

func (o *listOptions) Set(opts ...ListOption) {
func (o *listOptions) set(opts ...ListOption) {
for _, opt := range opts {
opt.applyListOption(o)
}
}

func (o *listOptions) validate() error {
return o.Namespace.validate()
}

// ListOption can be applied when listing tasks.
type ListOption interface {
applyListOption(*listOptions)
Expand All @@ -62,12 +66,16 @@ type scopeOptions struct {
Namespace namespace
}

func (o *scopeOptions) Set(opts ...ScopeOption) {
func (o *scopeOptions) set(opts ...ScopeOption) {
for _, opt := range opts {
opt.applyScopeOption(o)
}
}

func (o *scopeOptions) validate() error {
return o.Namespace.validate()
}

// ScopeOption can be applied when scoping results.
type ScopeOption interface {
applyScopeOption(*scopeOptions)
Expand Down

0 comments on commit 6e5f953

Please sign in to comment.