From 074632bbf5e4b955703f17c37171c1a21091f355 Mon Sep 17 00:00:00 2001 From: David Thorpe Date: Mon, 23 Mar 2026 14:07:27 +0100 Subject: [PATCH 1/7] Added subscribe mechanics --- bulk.go | 5 +++ conn.go | 11 ++++++ listener.go | 95 +++++++++++++++++++++++++++++++++++++++++++++++++++- pool.go | 28 +++++++++++++++- pool_test.go | 90 +++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 227 insertions(+), 2 deletions(-) diff --git a/bulk.go b/bulk.go index eb8d9bc..c044935 100644 --- a/bulk.go +++ b/bulk.go @@ -45,6 +45,11 @@ func (conn *bulkconn) Bulk(context.Context, func(Conn) error) error { return ErrNotImplemented } +// Subscribe is not supported for bulk connections. +func (conn *bulkconn) Subscribe(context.Context, string, func(Notification) error) error { + return ErrNotAvailable.With("subscribe requires pool-backed connection") +} + // Execute a query func (conn *bulkconn) Exec(context.Context, string) error { return ErrNotImplemented diff --git a/conn.go b/conn.go index 6e1fc52..df1220f 100644 --- a/conn.go +++ b/conn.go @@ -28,6 +28,11 @@ type Conn interface { // should be in a transaction) Bulk(context.Context, func(Conn) error) error + // Subscribe to a PostgreSQL notification channel. The callback is invoked + // serially for each payload until the context is cancelled, the pool is + // closed, or the callback returns an error. + Subscribe(context.Context, string, func(Notification) error) error + // Execute a query Exec(context.Context, string) error @@ -150,6 +155,12 @@ func (p *conn) Bulk(ctx context.Context, fn func(Conn) error) error { return bulk(ctx, p.conn, p.bind, fn) } +// Subscribe requires a pool-backed connection so the listener lifecycle can be +// managed independently of transactions. +func (p *conn) Subscribe(context.Context, string, func(Notification) error) error { + return ErrNotAvailable.With("subscribe requires pool-backed connection") +} + // Execute a query func (p *conn) Exec(ctx context.Context, query string) error { return p.bind.exec(ctx, p.conn, query) diff --git a/listener.go b/listener.go index 10fd057..1b151bb 100644 --- a/listener.go +++ b/listener.go @@ -2,12 +2,14 @@ package pg import ( "context" + "errors" "fmt" + "strings" "sync" // Packages - types "github.com/mutablelogic/go-pg/pkg/types" pgxpool "github.com/jackc/pgx/v5/pgxpool" + types "github.com/mutablelogic/go-pg/pkg/types" ) //////////////////////////////////////////////////////////////////////////////// @@ -41,6 +43,15 @@ type Notification struct { Payload []byte } +type subscriptionGroup struct { + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + mu sync.Mutex + closed bool + once sync.Once +} + //////////////////////////////////////////////////////////////////////////////// // LIFECYCLE @@ -52,6 +63,11 @@ func (pg *poolconn) Listener() Listener { return l } +func newSubscriptionGroup() *subscriptionGroup { + ctx, cancel := context.WithCancel(context.Background()) + return &subscriptionGroup{ctx: ctx, cancel: cancel} +} + // Close the connection to the database func (l *listener) Close(ctx context.Context) error { l.Lock() @@ -75,9 +91,86 @@ func (l *listener) Close(ctx context.Context) error { return err } +func (g *subscriptionGroup) Go(ctx context.Context, fn func(context.Context)) error { + g.mu.Lock() + if g.closed { + g.mu.Unlock() + return ErrNotAvailable.With("subscriptions are closed") + } + g.wg.Add(1) + parent := g.ctx + g.mu.Unlock() + + go func() { + defer g.wg.Done() + runCtx, cancel := context.WithCancel(parent) + stop := context.AfterFunc(ctx, cancel) + defer stop() + defer cancel() + fn(runCtx) + }() + + return nil +} + +func (g *subscriptionGroup) Close() { + g.once.Do(func() { + g.mu.Lock() + g.closed = true + g.mu.Unlock() + g.cancel() + g.wg.Wait() + }) +} + //////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS +func subscribe(ctx context.Context, pg *poolconn, channel string, fn func(Notification) error) error { + channel = strings.TrimSpace(channel) + if channel == "" { + return ErrBadParameter.With("channel is required") + } + if fn == nil { + return ErrBadParameter.With("callback is required") + } + + group := pg.conn.subscriptionGroup() + if group == nil { + return ErrNotAvailable.With("subscriptions are unavailable") + } + + listener := pg.Listener() + if listener == nil { + return ErrNotAvailable.With("listener is nil") + } + if err := listener.Listen(ctx, channel); err != nil { + return err + } + if err := group.Go(ctx, func(runCtx context.Context) { + defer listener.Unlisten(context.Background(), channel) + defer listener.Close(context.Background()) + + for { + n, err := listener.WaitForNotification(runCtx) + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return + } + return + } + if err := fn(*n); err != nil { + return + } + } + }); err != nil { + err = errors.Join(err, listener.Unlisten(context.Background(), channel)) + return errors.Join(err, listener.Close(context.Background())) + } + + return nil +} + // Connect to the database, and listen to a topic func (l *listener) Listen(ctx context.Context, topic string) error { l.Lock() diff --git a/pool.go b/pool.go index e328947..d149451 100644 --- a/pool.go +++ b/pool.go @@ -4,6 +4,7 @@ import ( "context" "errors" "strings" + "sync" // Packages pgx "github.com/jackc/pgx/v5" @@ -32,6 +33,8 @@ type PoolConn interface { type pool struct { *pgxpool.Pool + mu sync.Mutex + subscriptions *subscriptionGroup } type poolconn struct { @@ -82,7 +85,7 @@ func NewPool(ctx context.Context, opts ...Opt) (PoolConn, error) { } // Wrap the connection pool as if it's a transaction - return &poolconn{&pool{p}, o.bind}, nil + return &poolconn{&pool{Pool: p, subscriptions: newSubscriptionGroup()}, o.bind}, nil } //////////////////////////////////////////////////////////////////////////////// @@ -116,10 +119,12 @@ func (p *poolconn) Ping(ctx context.Context) error { } func (p *poolconn) Close() { + p.conn.resetSubscriptions() p.conn.Pool.Close() } func (p *poolconn) Reset() { + p.conn.resetSubscriptions() p.conn.Pool.Reset() } @@ -148,6 +153,11 @@ func (p *poolconn) Bulk(ctx context.Context, fn func(conn Conn) error) error { return bulk(ctx, p.conn, p.bind, fn) } +// Subscribe to a PostgreSQL notification channel using a dedicated connection. +func (p *poolconn) Subscribe(ctx context.Context, channel string, fn func(Notification) error) error { + return subscribe(ctx, p, channel, fn) +} + // Execute a query func (p *poolconn) Exec(ctx context.Context, query string) error { return p.bind.exec(ctx, p.conn, query) @@ -177,3 +187,19 @@ func (p *poolconn) Get(ctx context.Context, reader Reader, sel Selector) error { func (p *poolconn) List(ctx context.Context, reader Reader, sel Selector) error { return list(ctx, p.conn, p.bind, reader, sel) } + +func (p *pool) subscriptionGroup() *subscriptionGroup { + p.mu.Lock() + defer p.mu.Unlock() + return p.subscriptions +} + +func (p *pool) resetSubscriptions() { + p.mu.Lock() + group := p.subscriptions + p.subscriptions = newSubscriptionGroup() + p.mu.Unlock() + if group != nil { + group.Close() + } +} diff --git a/pool_test.go b/pool_test.go index 5f6249e..80c7f63 100644 --- a/pool_test.go +++ b/pool_test.go @@ -4,12 +4,16 @@ import ( "context" "encoding/json" "fmt" + "os" + "slices" "testing" + "time" // Packages pg "github.com/mutablelogic/go-pg" test "github.com/mutablelogic/go-pg/pkg/test" assert "github.com/stretchr/testify/assert" + require "github.com/stretchr/testify/require" ) // Global connection variable @@ -129,6 +133,92 @@ func Test_Pool_003(t *testing.T) { assert.NoError(err) } +func Test_Pool_004(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + conn := conn.Begin(t) + defer conn.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + channel := fmt.Sprintf("test_pool_subscribe_%d", time.Now().UnixNano()) + notifyCh := make(chan pg.Notification, 1) + + require.NoError(conn.Subscribe(ctx, channel, func(n pg.Notification) error { + notifyCh <- pg.Notification{Channel: n.Channel, Payload: append([]byte(nil), n.Payload...)} + return nil + })) + require.NoError(conn.Exec(context.Background(), fmt.Sprintf("SELECT pg_notify('%s', 'hello')", channel))) + + select { + case notify := <-notifyCh: + assert.Equal(channel, notify.Channel) + assert.Equal([]byte("hello"), notify.Payload) + case <-ctx.Done(): + t.Fatal("timeout waiting for notification") + } + + require.ErrorIs(conn.Tx(context.Background(), func(tx pg.Conn) error { + return tx.Subscribe(context.Background(), channel, func(pg.Notification) error { return nil }) + }), pg.ErrNotAvailable) + + require.ErrorIs(conn.Bulk(context.Background(), func(tx pg.Conn) error { + return tx.Subscribe(context.Background(), channel, func(pg.Notification) error { return nil }) + }), pg.ErrNotAvailable) +} + +func Test_Pool_005(t *testing.T) { + require := require.New(t) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + verbose := slices.Contains(os.Args, "-test.v=true") + container, pool, err := test.NewPgxContainer(ctx, t.Name(), verbose, nil) + require.NoError(err) + defer container.Close(context.Background()) + + channel := fmt.Sprintf("test_pool_close_%d", time.Now().UnixNano()) + started := make(chan struct{}) + release := make(chan struct{}) + closed := make(chan struct{}) + + require.NoError(pool.Subscribe(context.Background(), channel, func(n pg.Notification) error { + require.Equal(channel, n.Channel) + require.Equal([]byte("hello"), n.Payload) + close(started) + <-release + return nil + })) + require.NoError(pool.Exec(context.Background(), fmt.Sprintf("SELECT pg_notify('%s', 'hello')", channel))) + + select { + case <-started: + case <-ctx.Done(): + t.Fatal("timeout waiting for callback") + } + + go func() { + pool.Close() + close(closed) + }() + + select { + case <-closed: + t.Fatal("pool.Close returned before subscriber completed") + case <-time.After(200 * time.Millisecond): + } + + close(release) + + select { + case <-closed: + case <-ctx.Done(): + t.Fatal("timeout waiting for pool close") + } +} + //////////////////////////////////////////////////////////////////////////////// type Test struct { From 19745d1bd3bbb44a2d4f863c0382087425aa5ef3 Mon Sep 17 00:00:00 2001 From: David Thorpe Date: Mon, 23 Mar 2026 14:12:15 +0100 Subject: [PATCH 2/7] Updated README --- README.md | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index c8d1cc2..3b1ec31 100644 --- a/README.md +++ b/README.md @@ -401,30 +401,27 @@ the transaction will be committed. Transactions can be nested. ## Notify and Listen -PostgreSQL supports asynchronous notifications via `NOTIFY` and `LISTEN`. Obtain a `Listener` from the pool and call `Listen` to subscribe: +PostgreSQL supports asynchronous notifications via `NOTIFY` and `LISTEN`. The preferred API is `Subscribe`, which acquires a dedicated listener connection from the pool, invokes a callback for each notification, and automatically unsubscribes when the context is cancelled or the pool is closed. ```go import pg "github.com/mutablelogic/go-pg" -// Create a listener -listener := pool.Listener() -defer listener.Close(ctx) - -// Subscribe to a channel -if err := listener.Listen(ctx, "my_channel"); err != nil { +// Subscribe to a channel using a pool-backed connection. +if err := pool.Subscribe(ctx, "my_channel", func(n pg.Notification) error { + fmt.Printf("Channel: %s, Payload: %s\n", n.Channel, n.Payload) + return nil +}); err != nil { panic(err) } -// Wait for notifications -for { - notification, err := listener.WaitForNotification(ctx) - if err != nil { - return - } - fmt.Printf("Channel: %s, Payload: %s\n", notification.Channel, notification.Payload) -} +// Block until shutdown. +<-ctx.Done() ``` +Subscriptions are long-lived and tied to a dedicated PostgreSQL session, so they are only supported on pool-backed connections. Calling `Subscribe` from a transactional or bulk connection returns `pg.ErrNotAvailable`. + +When `pool.Close()` is called, all active subscriptions are cancelled and the pool waits for their callbacks to exit before returning. + To send a notification from another connection: ```go @@ -433,6 +430,8 @@ if err := pool.Exec(ctx, `NOTIFY my_channel, 'hello world'`); err != nil { } ``` +The lower-level `Listener` API still exists for compatibility, but `Subscribe` is the recommended interface for new code. + ## Schema Support The package provides convenience functions for managing PostgreSQL schemas: From 5e410632b2cfdca41d1d415fd58b7fe9d9ed3d3a Mon Sep 17 00:00:00 2001 From: David Thorpe Date: Mon, 23 Mar 2026 14:22:25 +0100 Subject: [PATCH 3/7] Updates --- README.md | 47 +++++++++- conn.go | 9 +- err.go | 64 -------------- error.go | 235 ++++++++++++++++++++++++++++++++++++++++++++++++++ error_test.go | 66 ++++++++++++++ listener.go | 74 ++++++++++++++-- pool.go | 2 + pool_test.go | 39 +++++++++ 8 files changed, 462 insertions(+), 74 deletions(-) delete mode 100644 err.go create mode 100644 error.go create mode 100644 error_test.go diff --git a/README.md b/README.md index 3b1ec31..82f9aff 100644 --- a/README.md +++ b/README.md @@ -420,6 +420,8 @@ if err := pool.Subscribe(ctx, "my_channel", func(n pg.Notification) error { Subscriptions are long-lived and tied to a dedicated PostgreSQL session, so they are only supported on pool-backed connections. Calling `Subscribe` from a transactional or bulk connection returns `pg.ErrNotAvailable`. +`Subscribe` returns setup errors only. After registration, the subscription runs in the background and stops if the callback returns an error or if the listener encounters a non-context error such as a dropped connection. + When `pool.Close()` is called, all active subscriptions are cancelled and the pool waits for their callbacks to exit before returning. To send a notification from another connection: @@ -449,9 +451,46 @@ err := pg.SchemaCreate(ctx, conn, "myschema") err := pg.SchemaDrop(ctx, conn, "myschema") ``` -## Error Handing and Tracing +## Error Handling and Tracing + +The package provides typed errors for common PostgreSQL conditions. Most query helpers already +normalize driver errors before returning them, but `pg.NormalizeError` is available if you need to +normalize a raw `pgx` or `pgconn` error yourself. + +Common checks are: + +* `pg.ErrNotFound` for `pgx.ErrNoRows` +* `pg.ErrConflict` and `pg.ErrUniqueViolation` for SQLSTATE `23505` +* `pg.ErrBadParameter` for common input and constraint errors such as foreign key, not null, + check constraint and invalid text/date formats +* `pg.ErrDatabase` for any PostgreSQL error with a SQLSTATE code + +```go +import ( + "errors" + + pg "github.com/mutablelogic/go-pg" +) + +err := conn.Update(ctx, &obj, req, req) +err = pg.NormalizeError(err) + +switch { +case errors.Is(err, pg.ErrNotFound): + // Row not found +case errors.Is(err, pg.ErrConflict): + // Duplicate key or other conflict +case errors.Is(err, pg.ErrBadParameter): + // Invalid user-supplied data +case errors.Is(err, pg.ErrDatabase): + // Other PostgreSQL error + log.Printf("sqlstate=%s err=%v", pg.SQLState(err), err) +case err != nil: + // Non-database error +} +``` -The package provides typed errors for common PostgreSQL conditions: +If you need the specific PostgreSQL code for logging or translation, use `pg.SQLState(err)`. ```go import pg "github.com/mutablelogic/go-pg" @@ -459,8 +498,12 @@ import pg "github.com/mutablelogic/go-pg" if err := conn.Get(ctx, &obj, req); err != nil { if errors.Is(err, pg.ErrNotFound) { // Row not found + } else if errors.Is(err, pg.ErrConflict) { + // Conflict } else if errors.Is(err, pg.ErrBadParameter) { // Invalid parameter + } else if errors.Is(err, pg.ErrDatabase) { + log.Printf("postgres error: sqlstate=%s err=%v", pg.SQLState(err), err) } else { // Other error } diff --git a/conn.go b/conn.go index df1220f..b6e8a8f 100644 --- a/conn.go +++ b/conn.go @@ -28,9 +28,12 @@ type Conn interface { // should be in a transaction) Bulk(context.Context, func(Conn) error) error - // Subscribe to a PostgreSQL notification channel. The callback is invoked - // serially for each payload until the context is cancelled, the pool is - // closed, or the callback returns an error. + // Subscribe to a PostgreSQL notification channel. Subscribe returns setup + // errors only; after successful registration the subscription runs in the + // background. The callback is invoked serially for each payload until the + // context is cancelled, the pool is closed, the callback returns an error, + // or the listener stops because WaitForNotification returns a non-context + // error such as a dropped connection. Subscribe(context.Context, string, func(Notification) error) error // Execute a query diff --git a/err.go b/err.go deleted file mode 100644 index ad4aa66..0000000 --- a/err.go +++ /dev/null @@ -1,64 +0,0 @@ -package pg - -import ( - "errors" - "fmt" - - // Packages - pgx "github.com/jackc/pgx/v5" -) - -///////////////////////////////////////////////////////////////////// -// TYPES - -type Err int - -///////////////////////////////////////////////////////////////////// -// GLOBALS - -const ( - ErrSuccess Err = iota - ErrNotFound - ErrNotImplemented - ErrBadParameter - ErrNotAvailable -) - -// Error returns the string representation of the error. -func (e Err) Error() string { - switch e { - case ErrSuccess: - return "success" - case ErrNotFound: - return "not found" - case ErrNotImplemented: - return "not implemented" - case ErrBadParameter: - return "bad parameter" - case ErrNotAvailable: - return "not available" - default: - return fmt.Sprint("Unknown error ", int(e)) - } -} - -// With returns the error with additional context appended. -func (e Err) With(a ...any) error { - return fmt.Errorf("%w: %s", e, fmt.Sprint(a...)) -} - -// Withf returns the error with formatted context appended. -func (e Err) Withf(format string, a ...any) error { - return fmt.Errorf("%w: %s", e, fmt.Sprintf(format, a...)) -} - -///////////////////////////////////////////////////////////////////// -// PUBLIC METHODS - -func pgerror(err error) error { - if errors.Is(err, pgx.ErrNoRows) { - return ErrNotFound - } else { - return err - } -} diff --git a/error.go b/error.go new file mode 100644 index 0000000..44e37f6 --- /dev/null +++ b/error.go @@ -0,0 +1,235 @@ +package pg + +import ( + "errors" + "fmt" + + // Packages + pgx "github.com/jackc/pgx/v5" + pgconn "github.com/jackc/pgx/v5/pgconn" +) + +///////////////////////////////////////////////////////////////////// +// TYPES + +type Err int + +type DatabaseError struct { + code string + message string + err error + kinds []Err +} + +///////////////////////////////////////////////////////////////////// +// GLOBALS + +const ( + ErrSuccess Err = iota + ErrNotFound + ErrNotImplemented + ErrBadParameter + ErrNotAvailable + ErrConflict + ErrDatabase + ErrUniqueViolation + ErrForeignKeyViolation + ErrNotNullViolation + ErrCheckViolation + ErrInvalidTextRepresentation + ErrInvalidDatetimeFormat + ErrDatetimeFieldOverflow +) + +const ( + sqlStateUniqueViolation = "23505" + sqlStateForeignKeyViolation = "23503" + sqlStateNotNullViolation = "23502" + sqlStateCheckViolation = "23514" + sqlStateInvalidTextRepresentation = "22P02" + sqlStateInvalidDatetimeFormat = "22007" + sqlStateDatetimeFieldOverflow = "22008" +) + +// Error returns the string representation of the error. +func (e Err) Error() string { + switch e { + case ErrSuccess: + return "success" + case ErrNotFound: + return "not found" + case ErrNotImplemented: + return "not implemented" + case ErrBadParameter: + return "bad parameter" + case ErrNotAvailable: + return "not available" + case ErrConflict: + return "conflict" + case ErrDatabase: + return "database error" + case ErrUniqueViolation: + return "unique violation" + case ErrForeignKeyViolation: + return "foreign key violation" + case ErrNotNullViolation: + return "not null violation" + case ErrCheckViolation: + return "check violation" + case ErrInvalidTextRepresentation: + return "invalid text representation" + case ErrInvalidDatetimeFormat: + return "invalid datetime format" + case ErrDatetimeFieldOverflow: + return "datetime field overflow" + default: + return fmt.Sprint("Unknown error ", int(e)) + } +} + +// With returns the error with additional context appended. +func (e Err) With(a ...any) error { + return fmt.Errorf("%w: %s", e, fmt.Sprint(a...)) +} + +// Withf returns the error with formatted context appended. +func (e Err) Withf(format string, a ...any) error { + return fmt.Errorf("%w: %s", e, fmt.Sprintf(format, a...)) +} + +// Error returns the wrapped database error string. +func (e *DatabaseError) Error() string { + if e == nil { + return "" + } + if e.err != nil { + return e.err.Error() + } + if e.message != "" { + return e.message + } + return e.code +} + +// Unwrap returns the underlying driver error. +func (e *DatabaseError) Unwrap() error { + if e == nil { + return nil + } + return e.err +} + +// Is supports errors.Is checks against broad and specific package errors. +func (e *DatabaseError) Is(target error) bool { + if e == nil { + return false + } + if kind, ok := target.(Err); ok { + for _, candidate := range e.kinds { + if candidate == kind { + return true + } + } + } + return errors.Is(e.err, target) +} + +// SQLState returns the PostgreSQL SQLSTATE code. +func (e *DatabaseError) SQLState() string { + if e == nil { + return "" + } + return e.code +} + +// Message returns the PostgreSQL error message. +func (e *DatabaseError) Message() string { + if e == nil { + return "" + } + return e.message +} + +///////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// NormalizeError maps driver-specific PostgreSQL errors to package errors. +func NormalizeError(err error) error { + if err == nil { + return nil + } + if errors.Is(err, ErrNotFound) { + return err + } + var dbErr *DatabaseError + if errors.As(err, &dbErr) { + return err + } + if errors.Is(err, pgx.ErrNoRows) { + return ErrNotFound + } + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + return newDatabaseError(pgErr) + } + return err +} + +// IsDatabaseError reports whether err is a PostgreSQL error with a SQLSTATE code. +func IsDatabaseError(err error) bool { + return SQLState(err) != "" +} + +// SQLState returns the PostgreSQL SQLSTATE code for err, if one is available. +func SQLState(err error) string { + if err == nil { + return "" + } + var dbErr *DatabaseError + if errors.As(err, &dbErr) { + return dbErr.SQLState() + } + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + return pgErr.Code + } + return "" +} + +func pgerror(err error) error { + return NormalizeError(err) +} + +///////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func newDatabaseError(err *pgconn.PgError) error { + if err == nil { + return nil + } + + kinds := []Err{ErrDatabase} + switch err.Code { + case sqlStateUniqueViolation: + kinds = append(kinds, ErrConflict, ErrUniqueViolation) + case sqlStateForeignKeyViolation: + kinds = append(kinds, ErrBadParameter, ErrForeignKeyViolation) + case sqlStateNotNullViolation: + kinds = append(kinds, ErrBadParameter, ErrNotNullViolation) + case sqlStateCheckViolation: + kinds = append(kinds, ErrBadParameter, ErrCheckViolation) + case sqlStateInvalidTextRepresentation: + kinds = append(kinds, ErrBadParameter, ErrInvalidTextRepresentation) + case sqlStateInvalidDatetimeFormat: + kinds = append(kinds, ErrBadParameter, ErrInvalidDatetimeFormat) + case sqlStateDatetimeFieldOverflow: + kinds = append(kinds, ErrBadParameter, ErrDatetimeFieldOverflow) + } + + return &DatabaseError{ + code: err.Code, + message: err.Message, + err: err, + kinds: kinds, + } +} \ No newline at end of file diff --git a/error_test.go b/error_test.go new file mode 100644 index 0000000..4f110cb --- /dev/null +++ b/error_test.go @@ -0,0 +1,66 @@ +package pg + +import ( + "testing" + + // Packages + pgx "github.com/jackc/pgx/v5" + pgconn "github.com/jackc/pgx/v5/pgconn" + assert "github.com/stretchr/testify/assert" + require "github.com/stretchr/testify/require" +) + +func Test_NormalizeError_001(t *testing.T) { + assert := assert.New(t) + + err := NormalizeError(pgx.ErrNoRows) + assert.ErrorIs(err, ErrNotFound) + assert.False(IsDatabaseError(err)) + assert.Empty(SQLState(err)) +} + +func Test_NormalizeError_002(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + err := NormalizeError(&pgconn.PgError{Code: sqlStateUniqueViolation, Message: "duplicate key value violates unique constraint"}) + assert.ErrorIs(err, ErrDatabase) + assert.ErrorIs(err, ErrConflict) + assert.ErrorIs(err, ErrUniqueViolation) + assert.NotErrorIs(err, ErrBadParameter) + assert.True(IsDatabaseError(err)) + assert.Equal(sqlStateUniqueViolation, SQLState(err)) + + var dbErr *DatabaseError + require.ErrorAs(err, &dbErr) + assert.Equal(sqlStateUniqueViolation, dbErr.SQLState()) + assert.Equal("duplicate key value violates unique constraint", dbErr.Message()) + assert.ErrorAs(err, new(*pgconn.PgError)) +} + +func Test_NormalizeError_003(t *testing.T) { + tests := []struct { + name string + code string + kind Err + }{ + {"foreign_key", sqlStateForeignKeyViolation, ErrForeignKeyViolation}, + {"not_null", sqlStateNotNullViolation, ErrNotNullViolation}, + {"check", sqlStateCheckViolation, ErrCheckViolation}, + {"invalid_text", sqlStateInvalidTextRepresentation, ErrInvalidTextRepresentation}, + {"invalid_datetime", sqlStateInvalidDatetimeFormat, ErrInvalidDatetimeFormat}, + {"datetime_overflow", sqlStateDatetimeFieldOverflow, ErrDatetimeFieldOverflow}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert := assert.New(t) + + err := NormalizeError(&pgconn.PgError{Code: tc.code, Message: tc.name}) + assert.ErrorIs(err, ErrDatabase) + assert.ErrorIs(err, ErrBadParameter) + assert.ErrorIs(err, tc.kind) + assert.Equal(tc.code, SQLState(err)) + }) + } + } \ No newline at end of file diff --git a/listener.go b/listener.go index 1b151bb..871733e 100644 --- a/listener.go +++ b/listener.go @@ -4,8 +4,11 @@ import ( "context" "errors" "fmt" + "runtime" + "strconv" "strings" "sync" + "time" // Packages pgxpool "github.com/jackc/pgx/v5/pgxpool" @@ -38,6 +41,8 @@ type listener struct { var _ Listener = (*listener)(nil) +const subscriptionCleanupTimeout = 5 * time.Second + type Notification struct { Channel string Payload []byte @@ -50,6 +55,7 @@ type subscriptionGroup struct { mu sync.Mutex closed bool once sync.Once + active map[uint64]struct{} } //////////////////////////////////////////////////////////////////////////////// @@ -65,7 +71,7 @@ func (pg *poolconn) Listener() Listener { func newSubscriptionGroup() *subscriptionGroup { ctx, cancel := context.WithCancel(context.Background()) - return &subscriptionGroup{ctx: ctx, cancel: cancel} + return &subscriptionGroup{ctx: ctx, cancel: cancel, active: make(map[uint64]struct{})} } // Close the connection to the database @@ -103,6 +109,8 @@ func (g *subscriptionGroup) Go(ctx context.Context, fn func(context.Context)) er go func() { defer g.wg.Done() + g.enter() + defer g.leave() runCtx, cancel := context.WithCancel(parent) stop := context.AfterFunc(ctx, cancel) defer stop() @@ -119,6 +127,9 @@ func (g *subscriptionGroup) Close() { g.closed = true g.mu.Unlock() g.cancel() + if g.containsCurrentGoroutine() { + return + } g.wg.Wait() }) } @@ -148,8 +159,9 @@ func subscribe(ctx context.Context, pg *poolconn, channel string, fn func(Notifi return err } if err := group.Go(ctx, func(runCtx context.Context) { - defer listener.Unlisten(context.Background(), channel) - defer listener.Close(context.Background()) + defer func() { + cleanupListener(listener, channel) + }() for { n, err := listener.WaitForNotification(runCtx) @@ -164,13 +176,65 @@ func subscribe(ctx context.Context, pg *poolconn, channel string, fn func(Notifi } } }); err != nil { - err = errors.Join(err, listener.Unlisten(context.Background(), channel)) - return errors.Join(err, listener.Close(context.Background())) + return errors.Join(err, cleanupListener(listener, channel)) } return nil } +func cleanupListener(listener Listener, channel string) error { + ctx, cancel := context.WithTimeout(context.Background(), subscriptionCleanupTimeout) + defer cancel() + + err := listener.Unlisten(ctx, channel) + return errors.Join(err, listener.Close(ctx)) +} + +func (g *subscriptionGroup) enter() { + id, ok := currentGoroutineID() + if !ok { + return + } + g.mu.Lock() + g.active[id] = struct{}{} + g.mu.Unlock() +} + +func (g *subscriptionGroup) leave() { + id, ok := currentGoroutineID() + if !ok { + return + } + g.mu.Lock() + delete(g.active, id) + g.mu.Unlock() +} + +func (g *subscriptionGroup) containsCurrentGoroutine() bool { + id, ok := currentGoroutineID() + if !ok { + return false + } + g.mu.Lock() + _, exists := g.active[id] + g.mu.Unlock() + return exists +} + +func currentGoroutineID() (uint64, bool) { + var buf [64]byte + n := runtime.Stack(buf[:], false) + fields := strings.Fields(string(buf[:n])) + if len(fields) < 2 || fields[0] != "goroutine" { + return 0, false + } + id, err := strconv.ParseUint(fields[1], 10, 64) + if err != nil { + return 0, false + } + return id, true +} + // Connect to the database, and listen to a topic func (l *listener) Listen(ctx context.Context, topic string) error { l.Lock() diff --git a/pool.go b/pool.go index d149451..630cb43 100644 --- a/pool.go +++ b/pool.go @@ -154,6 +154,8 @@ func (p *poolconn) Bulk(ctx context.Context, fn func(conn Conn) error) error { } // Subscribe to a PostgreSQL notification channel using a dedicated connection. +// Subscribe returns only initial registration errors; runtime listener errors +// terminate the background subscription. func (p *poolconn) Subscribe(ctx context.Context, channel string, fn func(Notification) error) error { return subscribe(ctx, p, channel, fn) } diff --git a/pool_test.go b/pool_test.go index 80c7f63..332d031 100644 --- a/pool_test.go +++ b/pool_test.go @@ -181,6 +181,7 @@ func Test_Pool_005(t *testing.T) { channel := fmt.Sprintf("test_pool_close_%d", time.Now().UnixNano()) started := make(chan struct{}) + closingStarted := make(chan struct{}) release := make(chan struct{}) closed := make(chan struct{}) @@ -200,10 +201,17 @@ func Test_Pool_005(t *testing.T) { } go func() { + close(closingStarted) pool.Close() close(closed) }() + select { + case <-closingStarted: + case <-ctx.Done(): + t.Fatal("timeout waiting for pool close to start") + } + select { case <-closed: t.Fatal("pool.Close returned before subscriber completed") @@ -219,6 +227,37 @@ func Test_Pool_005(t *testing.T) { } } +func Test_Pool_006(t *testing.T) { + require := require.New(t) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + verbose := slices.Contains(os.Args, "-test.v=true") + container, pool, err := test.NewPgxContainer(ctx, t.Name(), verbose, nil) + require.NoError(err) + defer container.Close(context.Background()) + + channel := fmt.Sprintf("test_pool_close_from_callback_%d", time.Now().UnixNano()) + callbackReturned := make(chan struct{}) + + require.NoError(pool.Subscribe(context.Background(), channel, func(n pg.Notification) error { + require.Equal(channel, n.Channel) + require.Equal([]byte("hello"), n.Payload) + + pool.Close() + close(callbackReturned) + return nil + })) + require.NoError(pool.Exec(context.Background(), fmt.Sprintf("SELECT pg_notify('%s', 'hello')", channel))) + + select { + case <-callbackReturned: + case <-ctx.Done(): + t.Fatal("timeout waiting for callback to return after pool.Close") + } +} + //////////////////////////////////////////////////////////////////////////////// type Test struct { From c773a1b9f556876bb7b7db6afe22e1c892026d1b Mon Sep 17 00:00:00 2001 From: David Thorpe Date: Mon, 23 Mar 2026 14:28:16 +0100 Subject: [PATCH 4/7] Updated --- listener.go | 11 +++++--- pool.go | 75 ++++++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 76 insertions(+), 10 deletions(-) diff --git a/listener.go b/listener.go index 871733e..5b5be23 100644 --- a/listener.go +++ b/listener.go @@ -127,13 +127,16 @@ func (g *subscriptionGroup) Close() { g.closed = true g.mu.Unlock() g.cancel() - if g.containsCurrentGoroutine() { - return - } - g.wg.Wait() }) } +func (g *subscriptionGroup) Wait() { + if g == nil { + return + } + g.wg.Wait() +} + //////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS diff --git a/pool.go b/pool.go index 630cb43..a9659b8 100644 --- a/pool.go +++ b/pool.go @@ -35,6 +35,9 @@ type pool struct { *pgxpool.Pool mu sync.Mutex subscriptions *subscriptionGroup + closingGroup *subscriptionGroup + closeOnce sync.Once + closeDone chan struct{} } type poolconn struct { @@ -85,7 +88,7 @@ func NewPool(ctx context.Context, opts ...Opt) (PoolConn, error) { } // Wrap the connection pool as if it's a transaction - return &poolconn{&pool{Pool: p, subscriptions: newSubscriptionGroup()}, o.bind}, nil + return &poolconn{&pool{Pool: p, subscriptions: newSubscriptionGroup(), closeDone: make(chan struct{})}, o.bind}, nil } //////////////////////////////////////////////////////////////////////////////// @@ -119,13 +122,11 @@ func (p *poolconn) Ping(ctx context.Context) error { } func (p *poolconn) Close() { - p.conn.resetSubscriptions() - p.conn.Pool.Close() + p.conn.close() } func (p *poolconn) Reset() { - p.conn.resetSubscriptions() - p.conn.Pool.Reset() + p.conn.reset() } // Return a new connection with new bound parameters @@ -196,7 +197,7 @@ func (p *pool) subscriptionGroup() *subscriptionGroup { return p.subscriptions } -func (p *pool) resetSubscriptions() { +func (p *pool) resetSubscriptions() *subscriptionGroup { p.mu.Lock() group := p.subscriptions p.subscriptions = newSubscriptionGroup() @@ -204,4 +205,66 @@ func (p *pool) resetSubscriptions() { if group != nil { group.Close() } + return group +} + +func (p *pool) close() { + shouldWait := true + + p.closeOnce.Do(func() { + group := p.resetSubscriptions() + p.setClosingGroup(group) + + finish := func() { + p.afterSubscriptions(group, p.Pool.Close) + p.setClosingGroup(nil) + close(p.closeDone) + } + + if group != nil && group.containsCurrentGoroutine() { + shouldWait = false + go finish() + return + } + + finish() + }) + + if shouldWait && !p.calledFromClosingGroup() { + <-p.closeDone + } +} + +func (p *pool) reset() { + group := p.resetSubscriptions() + if group != nil && group.containsCurrentGoroutine() { + go func() { + p.afterSubscriptions(group, p.Pool.Reset) + }() + return + } + p.afterSubscriptions(group, p.Pool.Reset) +} + +func (p *pool) calledFromClosingGroup() bool { + p.mu.Lock() + group := p.closingGroup + p.mu.Unlock() + if group == nil { + return false + } + return group.containsCurrentGoroutine() +} + +func (p *pool) setClosingGroup(group *subscriptionGroup) { + p.mu.Lock() + p.closingGroup = group + p.mu.Unlock() +} + +func (p *pool) afterSubscriptions(group *subscriptionGroup, fn func()) { + if group != nil { + group.Wait() + } + fn() } From 73ca01dcf468f8121a209fec67595e4594de97bd Mon Sep 17 00:00:00 2001 From: David Thorpe Date: Mon, 23 Mar 2026 14:34:33 +0100 Subject: [PATCH 5/7] Updated --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 82f9aff..40904d0 100644 --- a/README.md +++ b/README.md @@ -468,6 +468,7 @@ Common checks are: ```go import ( "errors" + "log" pg "github.com/mutablelogic/go-pg" ) From 4d9d8f362ba3bd255d6cd8891e8f970d263bdf31 Mon Sep 17 00:00:00 2001 From: David Thorpe Date: Mon, 23 Mar 2026 15:42:51 +0100 Subject: [PATCH 6/7] Updates --- README.md | 10 ++-- bulk.go | 2 +- conn.go | 13 ++--- listener.go | 130 ++++++---------------------------------------- pkg/queue/task.go | 6 ++- pool.go | 95 ++++++++++++--------------------- pool_test.go | 56 +++++--------------- subscription.go | 82 +++++++++++++++++++++++++++++ 8 files changed, 164 insertions(+), 230 deletions(-) create mode 100644 subscription.go diff --git a/README.md b/README.md index 40904d0..75c295e 100644 --- a/README.md +++ b/README.md @@ -404,10 +404,14 @@ the transaction will be committed. Transactions can be nested. PostgreSQL supports asynchronous notifications via `NOTIFY` and `LISTEN`. The preferred API is `Subscribe`, which acquires a dedicated listener connection from the pool, invokes a callback for each notification, and automatically unsubscribes when the context is cancelled or the pool is closed. ```go -import pg "github.com/mutablelogic/go-pg" +import ( + "context" + + pg "github.com/mutablelogic/go-pg" +) // Subscribe to a channel using a pool-backed connection. -if err := pool.Subscribe(ctx, "my_channel", func(n pg.Notification) error { +if err := pool.Subscribe(ctx, "my_channel", func(_ context.Context, n pg.Notification) error { fmt.Printf("Channel: %s, Payload: %s\n", n.Channel, n.Payload) return nil }); err != nil { @@ -420,7 +424,7 @@ if err := pool.Subscribe(ctx, "my_channel", func(n pg.Notification) error { Subscriptions are long-lived and tied to a dedicated PostgreSQL session, so they are only supported on pool-backed connections. Calling `Subscribe` from a transactional or bulk connection returns `pg.ErrNotAvailable`. -`Subscribe` returns setup errors only. After registration, the subscription runs in the background and stops if the callback returns an error or if the listener encounters a non-context error such as a dropped connection. +`Subscribe` returns setup errors only. After registration, the subscription runs in the background and stops if the callback returns an error or if the listener encounters a non-context error such as a dropped connection. The callback receives a context that is cancelled when the subscription is shutting down. When `pool.Close()` is called, all active subscriptions are cancelled and the pool waits for their callbacks to exit before returning. diff --git a/bulk.go b/bulk.go index c044935..bc119ab 100644 --- a/bulk.go +++ b/bulk.go @@ -46,7 +46,7 @@ func (conn *bulkconn) Bulk(context.Context, func(Conn) error) error { } // Subscribe is not supported for bulk connections. -func (conn *bulkconn) Subscribe(context.Context, string, func(Notification) error) error { +func (conn *bulkconn) Subscribe(context.Context, string, func(context.Context, Notification) error) error { return ErrNotAvailable.With("subscribe requires pool-backed connection") } diff --git a/conn.go b/conn.go index b6e8a8f..6c31b75 100644 --- a/conn.go +++ b/conn.go @@ -30,11 +30,12 @@ type Conn interface { // Subscribe to a PostgreSQL notification channel. Subscribe returns setup // errors only; after successful registration the subscription runs in the - // background. The callback is invoked serially for each payload until the - // context is cancelled, the pool is closed, the callback returns an error, - // or the listener stops because WaitForNotification returns a non-context - // error such as a dropped connection. - Subscribe(context.Context, string, func(Notification) error) error + // background. The callback is invoked serially for each payload with a + // context that is cancelled when the subscription stops. The subscription + // runs until the context is cancelled, the pool is closed, the callback + // returns an error, or the listener stops because WaitForNotification + // returns a non-context error such as a dropped connection. + Subscribe(context.Context, string, func(context.Context, Notification) error) error // Execute a query Exec(context.Context, string) error @@ -160,7 +161,7 @@ func (p *conn) Bulk(ctx context.Context, fn func(Conn) error) error { // Subscribe requires a pool-backed connection so the listener lifecycle can be // managed independently of transactions. -func (p *conn) Subscribe(context.Context, string, func(Notification) error) error { +func (p *conn) Subscribe(context.Context, string, func(context.Context, Notification) error) error { return ErrNotAvailable.With("subscribe requires pool-backed connection") } diff --git a/listener.go b/listener.go index 5b5be23..088045b 100644 --- a/listener.go +++ b/listener.go @@ -4,8 +4,6 @@ import ( "context" "errors" "fmt" - "runtime" - "strconv" "strings" "sync" "time" @@ -48,16 +46,6 @@ type Notification struct { Payload []byte } -type subscriptionGroup struct { - ctx context.Context - cancel context.CancelFunc - wg sync.WaitGroup - mu sync.Mutex - closed bool - once sync.Once - active map[uint64]struct{} -} - //////////////////////////////////////////////////////////////////////////////// // LIFECYCLE @@ -69,11 +57,6 @@ func (pg *poolconn) Listener() Listener { return l } -func newSubscriptionGroup() *subscriptionGroup { - ctx, cancel := context.WithCancel(context.Background()) - return &subscriptionGroup{ctx: ctx, cancel: cancel, active: make(map[uint64]struct{})} -} - // Close the connection to the database func (l *listener) Close(ctx context.Context) error { l.Lock() @@ -97,50 +80,10 @@ func (l *listener) Close(ctx context.Context) error { return err } -func (g *subscriptionGroup) Go(ctx context.Context, fn func(context.Context)) error { - g.mu.Lock() - if g.closed { - g.mu.Unlock() - return ErrNotAvailable.With("subscriptions are closed") - } - g.wg.Add(1) - parent := g.ctx - g.mu.Unlock() - - go func() { - defer g.wg.Done() - g.enter() - defer g.leave() - runCtx, cancel := context.WithCancel(parent) - stop := context.AfterFunc(ctx, cancel) - defer stop() - defer cancel() - fn(runCtx) - }() - - return nil -} - -func (g *subscriptionGroup) Close() { - g.once.Do(func() { - g.mu.Lock() - g.closed = true - g.mu.Unlock() - g.cancel() - }) -} - -func (g *subscriptionGroup) Wait() { - if g == nil { - return - } - g.wg.Wait() -} - //////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS -func subscribe(ctx context.Context, pg *poolconn, channel string, fn func(Notification) error) error { +func subscribe(ctx context.Context, pg *poolconn, channel string, fn func(context.Context, Notification) error) error { channel = strings.TrimSpace(channel) if channel == "" { return ErrBadParameter.With("channel is required") @@ -149,19 +92,25 @@ func subscribe(ctx context.Context, pg *poolconn, channel string, fn func(Notifi return ErrBadParameter.With("callback is required") } - group := pg.conn.subscriptionGroup() - if group == nil { - return ErrNotAvailable.With("subscriptions are unavailable") - } - listener := pg.Listener() if listener == nil { - return ErrNotAvailable.With("listener is nil") + return ErrNotAvailable.With("subscriptions are unavailable") } if err := listener.Listen(ctx, channel); err != nil { return err } - if err := group.Go(ctx, func(runCtx context.Context) { + + runCtx, cancel := context.WithCancel(ctx) + sub, err := pg.conn.addSubscription(cancel) + if err != nil { + cancel() + return errors.Join(err, cleanupListener(listener, channel)) + } + + go func() { + defer sub.Done() + defer pg.conn.removeSubscription(sub) + defer cancel() defer func() { cleanupListener(listener, channel) }() @@ -174,13 +123,11 @@ func subscribe(ctx context.Context, pg *poolconn, channel string, fn func(Notifi } return } - if err := fn(*n); err != nil { + if err := fn(runCtx, *n); err != nil { return } } - }); err != nil { - return errors.Join(err, cleanupListener(listener, channel)) - } + }() return nil } @@ -193,51 +140,6 @@ func cleanupListener(listener Listener, channel string) error { return errors.Join(err, listener.Close(ctx)) } -func (g *subscriptionGroup) enter() { - id, ok := currentGoroutineID() - if !ok { - return - } - g.mu.Lock() - g.active[id] = struct{}{} - g.mu.Unlock() -} - -func (g *subscriptionGroup) leave() { - id, ok := currentGoroutineID() - if !ok { - return - } - g.mu.Lock() - delete(g.active, id) - g.mu.Unlock() -} - -func (g *subscriptionGroup) containsCurrentGoroutine() bool { - id, ok := currentGoroutineID() - if !ok { - return false - } - g.mu.Lock() - _, exists := g.active[id] - g.mu.Unlock() - return exists -} - -func currentGoroutineID() (uint64, bool) { - var buf [64]byte - n := runtime.Stack(buf[:], false) - fields := strings.Fields(string(buf[:n])) - if len(fields) < 2 || fields[0] != "goroutine" { - return 0, false - } - id, err := strconv.ParseUint(fields[1], 10, 64) - if err != nil { - return 0, false - } - return id, true -} - // Connect to the database, and listen to a topic func (l *listener) Listen(ctx context.Context, topic string) error { l.Lock() diff --git a/pkg/queue/task.go b/pkg/queue/task.go index d851c27..f757580 100644 --- a/pkg/queue/task.go +++ b/pkg/queue/task.go @@ -160,7 +160,11 @@ func (manager *Manager) runTaskLoop(ctx context.Context, sem chan int, handler T if listener == nil { return pg.ErrBadParameter.With("listener is nil") } - defer listener.Close(context.Background()) + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = listener.Close(shutdownCtx) + }() // Subscribe to queue insert notifications for this namespace topic := manager.ns + schema.TopicQueueInsert diff --git a/pool.go b/pool.go index a9659b8..2f62a0b 100644 --- a/pool.go +++ b/pool.go @@ -34,8 +34,7 @@ type PoolConn interface { type pool struct { *pgxpool.Pool mu sync.Mutex - subscriptions *subscriptionGroup - closingGroup *subscriptionGroup + subscriptions *subscriptionArray closeOnce sync.Once closeDone chan struct{} } @@ -88,7 +87,7 @@ func NewPool(ctx context.Context, opts ...Opt) (PoolConn, error) { } // Wrap the connection pool as if it's a transaction - return &poolconn{&pool{Pool: p, subscriptions: newSubscriptionGroup(), closeDone: make(chan struct{})}, o.bind}, nil + return &poolconn{&pool{Pool: p, subscriptions: newSubscriptionArray(), closeDone: make(chan struct{})}, o.bind}, nil } //////////////////////////////////////////////////////////////////////////////// @@ -157,7 +156,7 @@ func (p *poolconn) Bulk(ctx context.Context, fn func(conn Conn) error) error { // Subscribe to a PostgreSQL notification channel using a dedicated connection. // Subscribe returns only initial registration errors; runtime listener errors // terminate the background subscription. -func (p *poolconn) Subscribe(ctx context.Context, channel string, fn func(Notification) error) error { +func (p *poolconn) Subscribe(ctx context.Context, channel string, fn func(context.Context, Notification) error) error { return subscribe(ctx, p, channel, fn) } @@ -191,80 +190,52 @@ func (p *poolconn) List(ctx context.Context, reader Reader, sel Selector) error return list(ctx, p.conn, p.bind, reader, sel) } -func (p *pool) subscriptionGroup() *subscriptionGroup { +func (p *pool) addSubscription(cancel context.CancelFunc) (*subscription, error) { p.mu.Lock() defer p.mu.Unlock() - return p.subscriptions -} -func (p *pool) resetSubscriptions() *subscriptionGroup { - p.mu.Lock() - group := p.subscriptions - p.subscriptions = newSubscriptionGroup() - p.mu.Unlock() - if group != nil { - group.Close() + if p.subscriptions == nil { + return nil, ErrNotAvailable.With("subscriptions are unavailable") } - return group -} - -func (p *pool) close() { - shouldWait := true - - p.closeOnce.Do(func() { - group := p.resetSubscriptions() - p.setClosingGroup(group) - - finish := func() { - p.afterSubscriptions(group, p.Pool.Close) - p.setClosingGroup(nil) - close(p.closeDone) - } - - if group != nil && group.containsCurrentGoroutine() { - shouldWait = false - go finish() - return - } - - finish() - }) - if shouldWait && !p.calledFromClosingGroup() { - <-p.closeDone - } + return p.subscriptions.Add(cancel) } -func (p *pool) reset() { - group := p.resetSubscriptions() - if group != nil && group.containsCurrentGoroutine() { - go func() { - p.afterSubscriptions(group, p.Pool.Reset) - }() - return +func (p *pool) removeSubscription(sub *subscription) { + p.mu.Lock() + defer p.mu.Unlock() + + if p.subscriptions != nil { + p.subscriptions.Remove(sub) } - p.afterSubscriptions(group, p.Pool.Reset) } -func (p *pool) calledFromClosingGroup() bool { +func (p *pool) swapSubscriptions(next *subscriptionArray) *subscriptionArray { p.mu.Lock() - group := p.closingGroup + current := p.subscriptions + p.subscriptions = next p.mu.Unlock() - if group == nil { - return false - } - return group.containsCurrentGoroutine() + return current } -func (p *pool) setClosingGroup(group *subscriptionGroup) { - p.mu.Lock() - p.closingGroup = group - p.mu.Unlock() +func (p *pool) close() { + p.closeOnce.Do(func() { + p.closeSubscriptions(p.swapSubscriptions(nil)) + p.Pool.Close() + close(p.closeDone) + }) + <-p.closeDone } -func (p *pool) afterSubscriptions(group *subscriptionGroup, fn func()) { +func (p *pool) reset() { + p.closeSubscriptions(p.swapSubscriptions(newSubscriptionArray())) + p.Pool.Reset() +} + +func (p *pool) closeSubscriptions(group *subscriptionArray) { if group != nil { - group.Wait() + for _, sub := range group.Close() { + sub.Wait() + } } - fn() } diff --git a/pool_test.go b/pool_test.go index 332d031..215b913 100644 --- a/pool_test.go +++ b/pool_test.go @@ -145,7 +145,8 @@ func Test_Pool_004(t *testing.T) { channel := fmt.Sprintf("test_pool_subscribe_%d", time.Now().UnixNano()) notifyCh := make(chan pg.Notification, 1) - require.NoError(conn.Subscribe(ctx, channel, func(n pg.Notification) error { + require.NoError(conn.Subscribe(ctx, channel, func(subCtx context.Context, n pg.Notification) error { + require.NoError(subCtx.Err()) notifyCh <- pg.Notification{Channel: n.Channel, Payload: append([]byte(nil), n.Payload...)} return nil })) @@ -160,11 +161,11 @@ func Test_Pool_004(t *testing.T) { } require.ErrorIs(conn.Tx(context.Background(), func(tx pg.Conn) error { - return tx.Subscribe(context.Background(), channel, func(pg.Notification) error { return nil }) + return tx.Subscribe(context.Background(), channel, func(context.Context, pg.Notification) error { return nil }) }), pg.ErrNotAvailable) require.ErrorIs(conn.Bulk(context.Background(), func(tx pg.Conn) error { - return tx.Subscribe(context.Background(), channel, func(pg.Notification) error { return nil }) + return tx.Subscribe(context.Background(), channel, func(context.Context, pg.Notification) error { return nil }) }), pg.ErrNotAvailable) } @@ -182,15 +183,17 @@ func Test_Pool_005(t *testing.T) { channel := fmt.Sprintf("test_pool_close_%d", time.Now().UnixNano()) started := make(chan struct{}) closingStarted := make(chan struct{}) - release := make(chan struct{}) + cancelled := make(chan struct{}) closed := make(chan struct{}) - require.NoError(pool.Subscribe(context.Background(), channel, func(n pg.Notification) error { + require.NoError(pool.Subscribe(context.Background(), channel, func(subCtx context.Context, n pg.Notification) error { + require.NoError(subCtx.Err()) require.Equal(channel, n.Channel) require.Equal([]byte("hello"), n.Payload) close(started) - <-release - return nil + <-subCtx.Done() + close(cancelled) + return subCtx.Err() })) require.NoError(pool.Exec(context.Background(), fmt.Sprintf("SELECT pg_notify('%s', 'hello')", channel))) @@ -213,13 +216,11 @@ func Test_Pool_005(t *testing.T) { } select { - case <-closed: - t.Fatal("pool.Close returned before subscriber completed") - case <-time.After(200 * time.Millisecond): + case <-cancelled: + case <-ctx.Done(): + t.Fatal("timeout waiting for callback cancellation") } - close(release) - select { case <-closed: case <-ctx.Done(): @@ -227,37 +228,6 @@ func Test_Pool_005(t *testing.T) { } } -func Test_Pool_006(t *testing.T) { - require := require.New(t) - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) - defer cancel() - - verbose := slices.Contains(os.Args, "-test.v=true") - container, pool, err := test.NewPgxContainer(ctx, t.Name(), verbose, nil) - require.NoError(err) - defer container.Close(context.Background()) - - channel := fmt.Sprintf("test_pool_close_from_callback_%d", time.Now().UnixNano()) - callbackReturned := make(chan struct{}) - - require.NoError(pool.Subscribe(context.Background(), channel, func(n pg.Notification) error { - require.Equal(channel, n.Channel) - require.Equal([]byte("hello"), n.Payload) - - pool.Close() - close(callbackReturned) - return nil - })) - require.NoError(pool.Exec(context.Background(), fmt.Sprintf("SELECT pg_notify('%s', 'hello')", channel))) - - select { - case <-callbackReturned: - case <-ctx.Done(): - t.Fatal("timeout waiting for callback to return after pool.Close") - } -} - //////////////////////////////////////////////////////////////////////////////// type Test struct { diff --git a/subscription.go b/subscription.go new file mode 100644 index 0000000..05565d8 --- /dev/null +++ b/subscription.go @@ -0,0 +1,82 @@ +package pg + +import ( + "context" + "sync" +) + +//////////////////////////////////////////////////////////////////////////////// +// TYPES + +type subscription struct { + cancel context.CancelFunc + done chan struct{} +} + +type subscriptionArray struct { + mu sync.Mutex + closed bool + items []*subscription +} + +//////////////////////////////////////////////////////////////////////////////// +// LIFECYCLE + +func newSubscriptionArray() *subscriptionArray { + return &subscriptionArray{} +} + +//////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +func (a *subscriptionArray) Add(cancel context.CancelFunc) (*subscription, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.closed { + return nil, ErrNotAvailable.With("subscriptions are closed") + } + + sub := &subscription{cancel: cancel, done: make(chan struct{})} + a.items = append(a.items, sub) + return sub, nil +} + +func (a *subscriptionArray) Remove(sub *subscription) { + a.mu.Lock() + defer a.mu.Unlock() + + for i, item := range a.items { + if item != sub { + continue + } + a.items = append(a.items[:i], a.items[i+1:]...) + return + } +} + +func (a *subscriptionArray) Close() []*subscription { + a.mu.Lock() + if a.closed { + items := append([]*subscription(nil), a.items...) + a.mu.Unlock() + return items + } + a.closed = true + items := append([]*subscription(nil), a.items...) + a.mu.Unlock() + + for _, sub := range items { + sub.cancel() + } + + return items +} + +func (s *subscription) Done() { + close(s.done) +} + +func (s *subscription) Wait() { + <-s.done +} From 770540981ef8a29135171e4095af42fad11f1665 Mon Sep 17 00:00:00 2001 From: David Thorpe Date: Mon, 23 Mar 2026 15:52:20 +0100 Subject: [PATCH 7/7] Changes after PR's --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 75c295e..f470bc0 100644 --- a/README.md +++ b/README.md @@ -406,6 +406,7 @@ PostgreSQL supports asynchronous notifications via `NOTIFY` and `LISTEN`. The pr ```go import ( "context" + "fmt" pg "github.com/mutablelogic/go-pg" )