diff --git a/server/console_account.go b/server/console_account.go index 86791f8f1..0203e385c 100644 --- a/server/console_account.go +++ b/server/console_account.go @@ -21,6 +21,10 @@ import ( "encoding/base64" "encoding/gob" "encoding/json" + "regexp" + "strconv" + "strings" + "github.com/gofrs/uuid" "github.com/heroiclabs/nakama-common/api" "github.com/heroiclabs/nakama/v3/console" @@ -31,9 +35,6 @@ import ( "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/timestamppb" - "regexp" - "strconv" - "strings" ) var validTrigramFilterRegex = regexp.MustCompile("^%?[^%]{3,}%?$") @@ -692,13 +693,7 @@ func (s *ConsoleServer) UpdateAccount(ctx context.Context, in *console.UpdateAcc return &emptypb.Empty{}, nil } - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - s.logger.Error("Could not begin database transaction.", zap.Error(err)) - return nil, status.Error(codes.Internal, "An error occurred while trying to update the user.") - } - - if err = ExecuteInTx(ctx, tx, func() error { + if err = ExecuteInTx(ctx, s.db, func(tx *sql.Tx) error { for oldDeviceID, newDeviceID := range in.DeviceIds { if newDeviceID == "" { query := `DELETE FROM user_device WHERE id = $2 AND user_id = $1 diff --git a/server/console_group.go b/server/console_group.go index e71c5b1d7..a5a589d2a 100644 --- a/server/console_group.go +++ b/server/console_group.go @@ -317,13 +317,7 @@ func (s *ConsoleServer) DemoteGroupMember(ctx context.Context, in *console.Updat var message *api.ChannelMessage ts := time.Now().Unix() - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return err - } - - if err := ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { query := "" if myState == 0 { // Ensure we aren't removing the last superadmin when deleting authoritatively. @@ -463,13 +457,7 @@ func (s *ConsoleServer) PromoteGroupMember(ctx context.Context, in *console.Upda var message *api.ChannelMessage ts := time.Now().Unix() - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return err - } - - if err := ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { if uid == caller { return errors.New("cannot promote self") } diff --git a/server/console_unlink.go b/server/console_unlink.go index ab70432bc..fc456f8c2 100644 --- a/server/console_unlink.go +++ b/server/console_unlink.go @@ -16,6 +16,7 @@ package server import ( "context" + "database/sql" "github.com/gofrs/uuid" "github.com/heroiclabs/nakama/v3/console" @@ -96,13 +97,7 @@ func (s *ConsoleServer) UnlinkDevice(ctx context.Context, in *console.UnlinkDevi return nil, status.Error(codes.InvalidArgument, "Requires a valid device ID.") } - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - s.logger.Error("Could not begin database transaction.", zap.Error(err)) - return nil, status.Error(codes.Internal, "Could not unlink Device ID.") - } - - err = ExecuteInTx(ctx, tx, func() error { + err = ExecuteInTx(ctx, s.db, func(tx *sql.Tx) error { query := `DELETE FROM user_device WHERE id = $2 AND user_id = $1 AND (EXISTS (SELECT id FROM users WHERE id = $1 AND (apple_id IS NOT NULL diff --git a/server/core_account.go b/server/core_account.go index 25d08eabc..c9266daef 100644 --- a/server/core_account.go +++ b/server/core_account.go @@ -243,13 +243,7 @@ WHERE u.id IN (` + strings.Join(statements, ",") + `)` } func UpdateAccounts(ctx context.Context, logger *zap.Logger, db *sql.DB, updates []*accountUpdate) error { - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return err - } - - if err = ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { updateErr := updateAccounts(ctx, logger, tx, updates) if updateErr != nil { return updateErr @@ -473,14 +467,8 @@ func ExportAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, userID u func DeleteAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, config Config, leaderboardRankCache LeaderboardRankCache, sessionRegistry SessionRegistry, sessionCache SessionCache, tracker Tracker, userID uuid.UUID, recorded bool) error { ts := time.Now().UTC().Unix() - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return err - } - var deleted bool - if err := ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { count, err := DeleteUser(ctx, tx, userID) if err != nil { logger.Debug("Could not delete user", zap.Error(err), zap.String("user_id", userID.String())) @@ -520,11 +508,11 @@ func DeleteAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, config C if deleted { // Logout and disconnect. - if err = SessionLogout(config, sessionCache, userID, "", ""); err != nil { + if err := SessionLogout(config, sessionCache, userID, "", ""); err != nil { return err } for _, presence := range tracker.ListPresenceIDByStream(PresenceStream{Mode: StreamModeNotifications, Subject: userID}) { - if err = sessionRegistry.Disconnect(ctx, presence.SessionID, false); err != nil { + if err := sessionRegistry.Disconnect(ctx, presence.SessionID, false); err != nil { return err } } diff --git a/server/core_authenticate.go b/server/core_authenticate.go index cb9e483a2..97e54c49e 100644 --- a/server/core_authenticate.go +++ b/server/core_authenticate.go @@ -225,13 +225,7 @@ func AuthenticateDevice(ctx context.Context, logger *zap.Logger, db *sql.DB, dev // Create a new account. userID := uuid.Must(uuid.NewV4()).String() - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return "", "", false, status.Error(codes.Internal, "Error finding or creating user account.") - } - - err = ExecuteInTx(ctx, tx, func() error { + err = ExecuteInTx(ctx, db, func(tx *sql.Tx) error { query := ` INSERT INTO users (id, username, create_time, update_time) SELECT $1 AS id, @@ -848,13 +842,7 @@ func importSteamFriends(ctx context.Context, logger *zap.Logger, db *sql.DB, mes } var friendUserIDs []uuid.UUID - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return status.Error(codes.Internal, "Error importing Steam friends.") - } - - err = ExecuteInTx(ctx, tx, func() error { + err = ExecuteInTx(ctx, db, func(tx *sql.Tx) error { if reset { if err := resetUserFriends(ctx, tx, userID); err != nil { logger.Error("Could not reset user friends", zap.Error(err)) @@ -930,13 +918,7 @@ func importFacebookFriends(ctx context.Context, logger *zap.Logger, db *sql.DB, } var friendUserIDs []uuid.UUID - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return status.Error(codes.Internal, "Error importing Facebook friends.") - } - - err = ExecuteInTx(ctx, tx, func() error { + err = ExecuteInTx(ctx, db, func(tx *sql.Tx) error { if reset { if err := resetUserFriends(ctx, tx, userID); err != nil { logger.Error("Could not reset user friends", zap.Error(err)) diff --git a/server/core_friend.go b/server/core_friend.go index 184b2b61d..024f40e05 100644 --- a/server/core_friend.go +++ b/server/core_friend.go @@ -23,10 +23,11 @@ import ( "encoding/json" "errors" "fmt" - "github.com/heroiclabs/nakama-common/runtime" "strconv" "time" + "github.com/heroiclabs/nakama-common/runtime" + "github.com/gofrs/uuid" "github.com/heroiclabs/nakama-common/api" "github.com/jackc/pgtype" @@ -226,13 +227,7 @@ func AddFriends(ctx context.Context, logger *zap.Logger, db *sql.DB, messageRout var notificationToSend map[string]bool - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return err - } - - if err = ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { // If the transaction is retried ensure we wipe any notifications that may have been prepared by previous attempts. notificationToSend = make(map[string]bool) @@ -373,13 +368,7 @@ func DeleteFriends(ctx context.Context, logger *zap.Logger, db *sql.DB, currentU uniqueFriendIDs[fid] = struct{}{} } - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return err - } - - if err = ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { for id := range uniqueFriendIDs { if deleteFriendErr := deleteFriend(ctx, logger, tx, currentUser, id); deleteFriendErr != nil { return deleteFriendErr @@ -428,13 +417,7 @@ func BlockFriends(ctx context.Context, logger *zap.Logger, db *sql.DB, currentUs uniqueFriendIDs[fid] = struct{}{} } - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return err - } - - if err = ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { for id := range uniqueFriendIDs { if blockFriendErr := blockFriend(ctx, logger, tx, currentUser, id); blockFriendErr != nil { return blockFriendErr diff --git a/server/core_group.go b/server/core_group.go index 237619097..9a07e1348 100644 --- a/server/core_group.go +++ b/server/core_group.go @@ -106,13 +106,7 @@ RETURNING id, creator_id, name, description, avatar_url, state, edge_count, lang var group *api.Group - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return nil, err - } - - if err = ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { rows, err := tx.QueryContext(ctx, query, params...) if err != nil { var pgErr *pgconn.PgError @@ -273,13 +267,7 @@ func DeleteGroup(ctx context.Context, logger *zap.Logger, db *sql.DB, groupID uu } } - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return err - } - - if err = ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { return deleteGroup(ctx, logger, tx, groupID) }); err != nil { logger.Error("Error deleting group.", zap.Error(err)) @@ -409,13 +397,7 @@ WHERE (id = $1) AND (disable_time = '1970-01-01 00:00:00 UTC')` GroupId: group.Id, } - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return err - } - - if err = ExecuteInTx(ctx, tx, func() error { + if err = ExecuteInTx(ctx, db, func(tx *sql.Tx) error { if _, err = groupAddUser(ctx, db, tx, uuid.Must(uuid.FromString(group.Id)), userID, state); err != nil { var pgErr *pgconn.PgError if errors.As(err, &pgErr) && pgErr.Code == dbErrorUniqueViolation { @@ -524,13 +506,7 @@ func LeaveGroup(ctx context.Context, logger *zap.Logger, db *sql.DB, tracker Tra GroupId: groupID.String(), } - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return err - } - - if err := ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { query = "DELETE FROM group_edge WHERE (source_id = $1::UUID AND destination_id = $2::UUID) OR (source_id = $2::UUID AND destination_id = $1::UUID)" // don't need to check affectedRows as we've confirmed the existence of the relationship above if _, err = tx.ExecContext(ctx, query, groupID, userID); err != nil { @@ -641,13 +617,7 @@ func AddGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, router M ts := time.Now().Unix() var messages []*api.ChannelMessage - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return err - } - - if err := ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { // If the transaction is retried ensure we wipe any notifications/messages that may have been prepared by previous attempts. notifications = make(map[uuid.UUID][]*api.Notification, len(userIDs)) messages = make([]*api.ChannelMessage, 0, len(userIDs)) @@ -800,13 +770,7 @@ func BanGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, tracker var messages []*api.ChannelMessage kicked := make(map[uuid.UUID]struct{}, len(userIDs)) - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return err - } - - if err := ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { // If the transaction is retried ensure we wipe any messages that may have been prepared by previous attempts. messages = make([]*api.ChannelMessage, 0, len(userIDs)) // Position to use for new banned edges. @@ -991,13 +955,7 @@ func KickGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, tracker var messages []*api.ChannelMessage kicked := make(map[uuid.UUID]struct{}, len(userIDs)) - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return err - } - - if err := ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { // If the transaction is retried ensure we wipe any messages that may have been prepared by previous attempts. messages = make([]*api.ChannelMessage, 0, len(userIDs)) @@ -1172,13 +1130,7 @@ func PromoteGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, rout ts := time.Now().Unix() var messages []*api.ChannelMessage - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return err - } - - if err := ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { // If the transaction is retried ensure we wipe any messages that may have been prepared by previous attempts. messages = make([]*api.ChannelMessage, 0, len(userIDs)) @@ -1303,13 +1255,7 @@ func DemoteGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, route ts := time.Now().Unix() var messages []*api.ChannelMessage - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return err - } - - if err := ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { // If the transaction is retried ensure we wipe any messages that may have been prepared by previous attempts. messages = make([]*api.ChannelMessage, 0, len(userIDs)) diff --git a/server/core_link.go b/server/core_link.go index e18acff0c..110c77d0c 100644 --- a/server/core_link.go +++ b/server/core_link.go @@ -118,13 +118,7 @@ func LinkDevice(ctx context.Context, logger *zap.Logger, db *sql.DB, userID uuid return status.Error(codes.InvalidArgument, "Device ID invalid, must be 10-128 bytes.") } - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return status.Error(codes.Internal, "Error linking Device ID.") - } - - err = ExecuteInTx(ctx, tx, func() error { + err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { var dbDeviceIDLinkedUser int64 err := tx.QueryRowContext(ctx, "SELECT COUNT(id) FROM user_device WHERE id = $1 AND user_id = $2 LIMIT 1", deviceID, userID).Scan(&dbDeviceIDLinkedUser) if err != nil { diff --git a/server/core_multi.go b/server/core_multi.go index 6194a3270..8b29800de 100644 --- a/server/core_multi.go +++ b/server/core_multi.go @@ -31,13 +31,7 @@ func MultiUpdate(ctx context.Context, logger *zap.Logger, db *sql.DB, metrics Me var storageWriteAcks []*api.StorageObjectAck var walletUpdateResults []*runtime.WalletUpdateResult - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return nil, nil, err - } - - if err = ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { storageWriteAcks = nil walletUpdateResults = nil diff --git a/server/core_storage.go b/server/core_storage.go index a0f36f3fa..7a2aa210d 100644 --- a/server/core_storage.go +++ b/server/core_storage.go @@ -467,13 +467,7 @@ WHERE func StorageWriteObjects(ctx context.Context, logger *zap.Logger, db *sql.DB, metrics Metrics, authoritativeWrite bool, ops StorageOpWrites) (*api.StorageObjectAcks, codes.Code, error) { var acks []*api.StorageObjectAck - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return nil, codes.Internal, err - } - - if err = ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { // If the transaction is retried ensure we wipe any acks that may have been prepared by previous attempts. var writeErr error acks, writeErr = storageWriteObjects(ctx, logger, metrics, tx, authoritativeWrite, ops) @@ -645,13 +639,7 @@ func StorageDeleteObjects(ctx context.Context, logger *zap.Logger, db *sql.DB, a // Ensure deletes are processed in a consistent order. sort.Sort(ops) - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return codes.Internal, err - } - - if err = ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { for _, op := range ops { params := []interface{}{op.ObjectID.Collection, op.ObjectID.Key, op.OwnerID} var query string diff --git a/server/core_tournament.go b/server/core_tournament.go index 28315a640..eecf83772 100644 --- a/server/core_tournament.go +++ b/server/core_tournament.go @@ -128,14 +128,8 @@ func TournamentJoin(ctx context.Context, logger *zap.Logger, db *sql.DB, cache L return runtime.ErrTournamentOutsideDuration } - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return err - } - var isNewJoin bool - if err = ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { query := `INSERT INTO leaderboard_record (leaderboard_id, owner_id, expiry_time, username, num_score, max_num_score) VALUES @@ -558,13 +552,7 @@ func TournamentRecordWrite(ctx context.Context, logger *zap.Logger, db *sql.DB, DO UPDATE SET ` + opSQL + `, num_score = leaderboard_record.num_score + 1, metadata = COALESCE($7, leaderboard_record.metadata), username = COALESCE($3, leaderboard_record.username), update_time = now()` + filterSQL params = append(params, leaderboard.MaxNumScore, scoreAbs, subscoreAbs) - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return nil, err - } - - if err := ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { recordQueryResult, err := tx.ExecContext(ctx, query, params...) if err != nil { var pgErr *pgconn.PgError diff --git a/server/core_unlink.go b/server/core_unlink.go index 629a26e5b..15bf995f0 100644 --- a/server/core_unlink.go +++ b/server/core_unlink.go @@ -99,13 +99,7 @@ func UnlinkDevice(ctx context.Context, logger *zap.Logger, db *sql.DB, id uuid.U return status.Error(codes.InvalidArgument, "A device ID must be supplied.") } - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return status.Error(codes.Internal, "Could not unlink Device ID.") - } - - err = ExecuteInTx(ctx, tx, func() error { + err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { res, err := tx.ExecContext(ctx, `DELETE FROM user_device WHERE id = $2 AND user_id = $1 AND (EXISTS (SELECT id FROM users WHERE id = $1 AND (apple_id IS NOT NULL diff --git a/server/core_wallet.go b/server/core_wallet.go index cb270e503..4a1a884a4 100644 --- a/server/core_wallet.go +++ b/server/core_wallet.go @@ -89,13 +89,7 @@ func UpdateWallets(ctx context.Context, logger *zap.Logger, db *sql.DB, updates var results []*runtime.WalletUpdateResult - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Error("Could not begin database transaction.", zap.Error(err)) - return nil, err - } - - if err = ExecuteInTx(ctx, tx, func() error { + if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error { var updateErr error results, updateErr = updateWallets(ctx, logger, tx, updates, updateLedger) if updateErr != nil { diff --git a/server/db.go b/server/db.go index 725ae6087..2b0c3c078 100644 --- a/server/db.go +++ b/server/db.go @@ -32,6 +32,8 @@ import ( var ErrDatabaseDriverMismatch = errors.New("database driver mismatch") +var isCockroach bool + func DbConnect(ctx context.Context, logger *zap.Logger, config Config) (*sql.DB, string) { rawURL := config.GetDatabase().Addresses[0] if !(strings.HasPrefix(rawURL, "postgresql://") || strings.HasPrefix(rawURL, "postgres://")) { @@ -89,6 +91,11 @@ func DbConnect(ctx context.Context, logger *zap.Logger, config Config) (*sql.DB, if err = db.QueryRowContext(pingCtx, "SELECT version()").Scan(&dbVersion); err != nil { logger.Fatal("Error querying database version", zap.Error(err)) } + if strings.Split(dbVersion, " ")[0] == "CockroachDB" { + isCockroach = true + } else { + isCockroach = false + } // Periodically check database hostname for underlying address changes. go func() { @@ -224,13 +231,61 @@ func ExecuteRetryable(fn func() error) error { } // ExecuteInTx runs fn inside tx which should already have begun. -// *WARNING*: Do not execute any statements on the supplied tx before calling this function. -// ExecuteInTx will only retry statements that are performed within the supplied -// closure (fn). Any statements performed on the tx before ExecuteInTx is invoked will *not* -// be re-run if the transaction needs to be retried. -// // fn is subject to the same restrictions as the fn passed to ExecuteTx. -func ExecuteInTx(ctx context.Context, tx Tx, fn func() error) (err error) { +func ExecuteInTx(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) error { + if isCockroach { + return ExecuteInTxCockroach(ctx, db, fn) + } else { + return ExecuteInTxPostgres(ctx, db, fn) + } +} + +// Retries fn() if transaction commit returned retryable error code +// Every call to fn() happens in its own transaction. On retry previous transaction +// is ROLLBACK'ed and new transaction is opened. +func ExecuteInTxPostgres(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) (err error) { + var tx *sql.Tx + defer func() { + if tx != nil { + _ = tx.Rollback() + } + }() + + // Prevent infinite loop (unlikely, but possible) + for i := 0; i < 5; i++ { + if tx, err = db.BeginTx(ctx, nil); err != nil { // Can fail only if undernath connection is broken + tx = nil + return err + } + if err = fn(tx); err == nil { + err = tx.Commit() + } + var pgErr *pgconn.PgError + if errors.As(errorCause(err), &pgErr) && pgErr.Code[:2] == "40" { + // 40XXXX codes are retriable errors + if err = tx.Rollback(); err != nil && err != sql.ErrTxDone { + tx = nil + return err + } + continue + } else { + // Exit on successfull Commit or non retriable error + return err + } + } + // Stop trying after 5 attempts and return last op error + return err +} + +// CockroachDB has it's own way to resolve serialization conflicts. +// It has special optimization for `SAVEPOINT cockroach_restart`, called "retry savepoint", +// which increases transaction priority every time it has to ROLLBACK due to serialization conflicts. +// See: https://www.cockroachlabs.com/docs/stable/advanced-client-side-transaction-retries.html +func ExecuteInTxCockroach(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) error { + tx, err := db.BeginTx(ctx, nil) + if err != nil { // Can fail only if undernath connection is broken + return err + } defer func() { if err == nil { // Ignore commit errors. The tx has already been committed by RELEASE. @@ -246,9 +301,10 @@ func ExecuteInTx(ctx context.Context, tx Tx, fn func() error) (err error) { return err } - for { + // Prevent infinite loop (unlikely, but possible) + for i := 0; i < 5; i++ { released := false - err = fn() + err = fn(tx) if err == nil { // RELEASE acts like COMMIT in CockroachDB. We use it since it gives us an // opportunity to react to retryable errors, whereas tx.Commit() doesn't. @@ -272,4 +328,6 @@ func ExecuteInTx(ctx context.Context, tx Tx, fn func() error) (err error) { return newTxnRestartError(retryErr, err) } } + // Stop trying after 5 attempts and return last op error + return err }