Skip to content

Commit

Permalink
ExecuteInTx PostgreSQL version (#1045)
Browse files Browse the repository at this point in the history
* Enforce contract of ExecuteInTx at the API level

Previously ExecuteInTx accepted open transaction, but required
users never to execute any commands on it prior to calling
ExecuteInTx. This API change enforces this contract by making
ExecuteInTx to open transaction internally and pass it to the
callback func.

* Implement PG version of ExecuteInTx which does fewer roundtrips to the Server

PostgreSQL doesn't benefit from SAVEPOINT/ROLLBACK logic like CockroachDB
does. With this change Nakama checks server DB engine and enables CockroachDB
optimization only when necessary.

There are 2 behviour change in the PG version of ExecuteInTx:

- it retries on all "Class 40" (a.k.a retriable) codes, not just
  serialization error:

	40000 	transaction_rollback
	40002 	transaction_integrity_constraint_violation
	40001 	serialization_failure
	40003 	statement_completion_unknown
	40P01 	deadlock_detected

- It doesn't ignore COMMIT result code anymore
  • Loading branch information
redbaron committed Jul 11, 2023
1 parent e08c03e commit 21e3cbd
Show file tree
Hide file tree
Showing 14 changed files with 104 additions and 217 deletions.
15 changes: 5 additions & 10 deletions server/console_account.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ import (
"encoding/base64"
"encoding/gob"
"encoding/json"
"regexp"
"strconv"
"strings"

"github.com/gofrs/uuid"
"github.com/heroiclabs/nakama-common/api"
"github.com/heroiclabs/nakama/v3/console"
Expand All @@ -31,9 +35,6 @@ import (
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
"regexp"
"strconv"
"strings"
)

var validTrigramFilterRegex = regexp.MustCompile("^%?[^%]{3,}%?$")
Expand Down Expand Up @@ -692,13 +693,7 @@ func (s *ConsoleServer) UpdateAccount(ctx context.Context, in *console.UpdateAcc
return &emptypb.Empty{}, nil
}

tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
s.logger.Error("Could not begin database transaction.", zap.Error(err))
return nil, status.Error(codes.Internal, "An error occurred while trying to update the user.")
}

if err = ExecuteInTx(ctx, tx, func() error {
if err = ExecuteInTx(ctx, s.db, func(tx *sql.Tx) error {
for oldDeviceID, newDeviceID := range in.DeviceIds {
if newDeviceID == "" {
query := `DELETE FROM user_device WHERE id = $2 AND user_id = $1
Expand Down
16 changes: 2 additions & 14 deletions server/console_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,13 +317,7 @@ func (s *ConsoleServer) DemoteGroupMember(ctx context.Context, in *console.Updat
var message *api.ChannelMessage
ts := time.Now().Unix()

tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return err
}

if err := ExecuteInTx(ctx, tx, func() error {
if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
query := ""
if myState == 0 {
// Ensure we aren't removing the last superadmin when deleting authoritatively.
Expand Down Expand Up @@ -463,13 +457,7 @@ func (s *ConsoleServer) PromoteGroupMember(ctx context.Context, in *console.Upda
var message *api.ChannelMessage
ts := time.Now().Unix()

tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return err
}

if err := ExecuteInTx(ctx, tx, func() error {
if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
if uid == caller {
return errors.New("cannot promote self")
}
Expand Down
9 changes: 2 additions & 7 deletions server/console_unlink.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package server

import (
"context"
"database/sql"

"github.com/gofrs/uuid"
"github.com/heroiclabs/nakama/v3/console"
Expand Down Expand Up @@ -96,13 +97,7 @@ func (s *ConsoleServer) UnlinkDevice(ctx context.Context, in *console.UnlinkDevi
return nil, status.Error(codes.InvalidArgument, "Requires a valid device ID.")
}

tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
s.logger.Error("Could not begin database transaction.", zap.Error(err))
return nil, status.Error(codes.Internal, "Could not unlink Device ID.")
}

err = ExecuteInTx(ctx, tx, func() error {
err = ExecuteInTx(ctx, s.db, func(tx *sql.Tx) error {
query := `DELETE FROM user_device WHERE id = $2 AND user_id = $1
AND (EXISTS (SELECT id FROM users WHERE id = $1 AND
(apple_id IS NOT NULL
Expand Down
20 changes: 4 additions & 16 deletions server/core_account.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,13 +243,7 @@ WHERE u.id IN (` + strings.Join(statements, ",") + `)`
}

func UpdateAccounts(ctx context.Context, logger *zap.Logger, db *sql.DB, updates []*accountUpdate) error {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return err
}

if err = ExecuteInTx(ctx, tx, func() error {
if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
updateErr := updateAccounts(ctx, logger, tx, updates)
if updateErr != nil {
return updateErr
Expand Down Expand Up @@ -473,14 +467,8 @@ func ExportAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, userID u
func DeleteAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, config Config, leaderboardRankCache LeaderboardRankCache, sessionRegistry SessionRegistry, sessionCache SessionCache, tracker Tracker, userID uuid.UUID, recorded bool) error {
ts := time.Now().UTC().Unix()

tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return err
}

var deleted bool
if err := ExecuteInTx(ctx, tx, func() error {
if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
count, err := DeleteUser(ctx, tx, userID)
if err != nil {
logger.Debug("Could not delete user", zap.Error(err), zap.String("user_id", userID.String()))
Expand Down Expand Up @@ -520,11 +508,11 @@ func DeleteAccount(ctx context.Context, logger *zap.Logger, db *sql.DB, config C

if deleted {
// Logout and disconnect.
if err = SessionLogout(config, sessionCache, userID, "", ""); err != nil {
if err := SessionLogout(config, sessionCache, userID, "", ""); err != nil {
return err
}
for _, presence := range tracker.ListPresenceIDByStream(PresenceStream{Mode: StreamModeNotifications, Subject: userID}) {
if err = sessionRegistry.Disconnect(ctx, presence.SessionID, false); err != nil {
if err := sessionRegistry.Disconnect(ctx, presence.SessionID, false); err != nil {
return err
}
}
Expand Down
24 changes: 3 additions & 21 deletions server/core_authenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,7 @@ func AuthenticateDevice(ctx context.Context, logger *zap.Logger, db *sql.DB, dev
// Create a new account.
userID := uuid.Must(uuid.NewV4()).String()

tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return "", "", false, status.Error(codes.Internal, "Error finding or creating user account.")
}

err = ExecuteInTx(ctx, tx, func() error {
err = ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
query := `
INSERT INTO users (id, username, create_time, update_time)
SELECT $1 AS id,
Expand Down Expand Up @@ -848,13 +842,7 @@ func importSteamFriends(ctx context.Context, logger *zap.Logger, db *sql.DB, mes
}

var friendUserIDs []uuid.UUID
tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return status.Error(codes.Internal, "Error importing Steam friends.")
}

err = ExecuteInTx(ctx, tx, func() error {
err = ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
if reset {
if err := resetUserFriends(ctx, tx, userID); err != nil {
logger.Error("Could not reset user friends", zap.Error(err))
Expand Down Expand Up @@ -930,13 +918,7 @@ func importFacebookFriends(ctx context.Context, logger *zap.Logger, db *sql.DB,
}

var friendUserIDs []uuid.UUID
tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return status.Error(codes.Internal, "Error importing Facebook friends.")
}

err = ExecuteInTx(ctx, tx, func() error {
err = ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
if reset {
if err := resetUserFriends(ctx, tx, userID); err != nil {
logger.Error("Could not reset user friends", zap.Error(err))
Expand Down
27 changes: 5 additions & 22 deletions server/core_friend.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/heroiclabs/nakama-common/runtime"
"strconv"
"time"

"github.com/heroiclabs/nakama-common/runtime"

"github.com/gofrs/uuid"
"github.com/heroiclabs/nakama-common/api"
"github.com/jackc/pgtype"
Expand Down Expand Up @@ -226,13 +227,7 @@ func AddFriends(ctx context.Context, logger *zap.Logger, db *sql.DB, messageRout

var notificationToSend map[string]bool

tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return err
}

if err = ExecuteInTx(ctx, tx, func() error {
if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
// If the transaction is retried ensure we wipe any notifications that may have been prepared by previous attempts.
notificationToSend = make(map[string]bool)

Expand Down Expand Up @@ -373,13 +368,7 @@ func DeleteFriends(ctx context.Context, logger *zap.Logger, db *sql.DB, currentU
uniqueFriendIDs[fid] = struct{}{}
}

tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return err
}

if err = ExecuteInTx(ctx, tx, func() error {
if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
for id := range uniqueFriendIDs {
if deleteFriendErr := deleteFriend(ctx, logger, tx, currentUser, id); deleteFriendErr != nil {
return deleteFriendErr
Expand Down Expand Up @@ -428,13 +417,7 @@ func BlockFriends(ctx context.Context, logger *zap.Logger, db *sql.DB, currentUs
uniqueFriendIDs[fid] = struct{}{}
}

tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return err
}

if err = ExecuteInTx(ctx, tx, func() error {
if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
for id := range uniqueFriendIDs {
if blockFriendErr := blockFriend(ctx, logger, tx, currentUser, id); blockFriendErr != nil {
return blockFriendErr
Expand Down
72 changes: 9 additions & 63 deletions server/core_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,7 @@ RETURNING id, creator_id, name, description, avatar_url, state, edge_count, lang

var group *api.Group

tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return nil, err
}

if err = ExecuteInTx(ctx, tx, func() error {
if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
rows, err := tx.QueryContext(ctx, query, params...)
if err != nil {
var pgErr *pgconn.PgError
Expand Down Expand Up @@ -273,13 +267,7 @@ func DeleteGroup(ctx context.Context, logger *zap.Logger, db *sql.DB, groupID uu
}
}

tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return err
}

if err = ExecuteInTx(ctx, tx, func() error {
if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
return deleteGroup(ctx, logger, tx, groupID)
}); err != nil {
logger.Error("Error deleting group.", zap.Error(err))
Expand Down Expand Up @@ -409,13 +397,7 @@ WHERE (id = $1) AND (disable_time = '1970-01-01 00:00:00 UTC')`
GroupId: group.Id,
}

tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return err
}

if err = ExecuteInTx(ctx, tx, func() error {
if err = ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
if _, err = groupAddUser(ctx, db, tx, uuid.Must(uuid.FromString(group.Id)), userID, state); err != nil {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == dbErrorUniqueViolation {
Expand Down Expand Up @@ -524,13 +506,7 @@ func LeaveGroup(ctx context.Context, logger *zap.Logger, db *sql.DB, tracker Tra
GroupId: groupID.String(),
}

tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return err
}

if err := ExecuteInTx(ctx, tx, func() error {
if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
query = "DELETE FROM group_edge WHERE (source_id = $1::UUID AND destination_id = $2::UUID) OR (source_id = $2::UUID AND destination_id = $1::UUID)"
// don't need to check affectedRows as we've confirmed the existence of the relationship above
if _, err = tx.ExecContext(ctx, query, groupID, userID); err != nil {
Expand Down Expand Up @@ -641,13 +617,7 @@ func AddGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, router M
ts := time.Now().Unix()
var messages []*api.ChannelMessage

tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return err
}

if err := ExecuteInTx(ctx, tx, func() error {
if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
// If the transaction is retried ensure we wipe any notifications/messages that may have been prepared by previous attempts.
notifications = make(map[uuid.UUID][]*api.Notification, len(userIDs))
messages = make([]*api.ChannelMessage, 0, len(userIDs))
Expand Down Expand Up @@ -800,13 +770,7 @@ func BanGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, tracker
var messages []*api.ChannelMessage
kicked := make(map[uuid.UUID]struct{}, len(userIDs))

tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return err
}

if err := ExecuteInTx(ctx, tx, func() error {
if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
// If the transaction is retried ensure we wipe any messages that may have been prepared by previous attempts.
messages = make([]*api.ChannelMessage, 0, len(userIDs))
// Position to use for new banned edges.
Expand Down Expand Up @@ -991,13 +955,7 @@ func KickGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, tracker
var messages []*api.ChannelMessage
kicked := make(map[uuid.UUID]struct{}, len(userIDs))

tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return err
}

if err := ExecuteInTx(ctx, tx, func() error {
if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
// If the transaction is retried ensure we wipe any messages that may have been prepared by previous attempts.
messages = make([]*api.ChannelMessage, 0, len(userIDs))

Expand Down Expand Up @@ -1172,13 +1130,7 @@ func PromoteGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, rout
ts := time.Now().Unix()
var messages []*api.ChannelMessage

tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return err
}

if err := ExecuteInTx(ctx, tx, func() error {
if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
// If the transaction is retried ensure we wipe any messages that may have been prepared by previous attempts.
messages = make([]*api.ChannelMessage, 0, len(userIDs))

Expand Down Expand Up @@ -1303,13 +1255,7 @@ func DemoteGroupUsers(ctx context.Context, logger *zap.Logger, db *sql.DB, route
ts := time.Now().Unix()
var messages []*api.ChannelMessage

tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return err
}

if err := ExecuteInTx(ctx, tx, func() error {
if err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
// If the transaction is retried ensure we wipe any messages that may have been prepared by previous attempts.
messages = make([]*api.ChannelMessage, 0, len(userIDs))

Expand Down
8 changes: 1 addition & 7 deletions server/core_link.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,7 @@ func LinkDevice(ctx context.Context, logger *zap.Logger, db *sql.DB, userID uuid
return status.Error(codes.InvalidArgument, "Device ID invalid, must be 10-128 bytes.")
}

tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Error("Could not begin database transaction.", zap.Error(err))
return status.Error(codes.Internal, "Error linking Device ID.")
}

err = ExecuteInTx(ctx, tx, func() error {
err := ExecuteInTx(ctx, db, func(tx *sql.Tx) error {
var dbDeviceIDLinkedUser int64
err := tx.QueryRowContext(ctx, "SELECT COUNT(id) FROM user_device WHERE id = $1 AND user_id = $2 LIMIT 1", deviceID, userID).Scan(&dbDeviceIDLinkedUser)
if err != nil {
Expand Down
Loading

0 comments on commit 21e3cbd

Please sign in to comment.