Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow detect whether it's in a database transaction for a context.Context #21756

Merged
merged 16 commits into from Nov 12, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion models/activities/action.go
Expand Up @@ -572,7 +572,7 @@ func NotifyWatchers(actions ...*Action) error {

// NotifyWatchersActions creates batch of actions for every watcher.
func NotifyWatchersActions(acts []*Action) error {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions models/activities/notification.go
Expand Up @@ -142,7 +142,7 @@ func CountNotifications(opts *FindNotificationOptions) (int64, error) {

// CreateRepoTransferNotification creates notification for the user a repository was transferred to
func CreateRepoTransferNotification(doer, newOwner *user_model.User, repo *repo_model.Repository) error {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
Expand Down Expand Up @@ -185,7 +185,7 @@ func CreateRepoTransferNotification(doer, newOwner *user_model.User, repo *repo_
// for each watcher, or updates it if already exists
// receiverID > 0 just send to receiver, else send to all watcher
func CreateOrUpdateIssueNotifications(issueID, commentID, notificationAuthorID, receiverID int64) error {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion models/asymkey/gpg_key.go
Expand Up @@ -234,7 +234,7 @@ func DeleteGPGKey(doer *user_model.User, id int64) (err error) {
return ErrGPGKeyAccessDenied{doer.ID, key.ID}
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion models/asymkey/gpg_key_add.go
Expand Up @@ -73,7 +73,7 @@ func AddGPGKey(ownerID int64, content, token, signature string) ([]*GPGKey, erro
return nil, err
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion models/asymkey/gpg_key_verify.go
Expand Up @@ -31,7 +31,7 @@ import (

// VerifyGPGKey marks a GPG key as verified
func VerifyGPGKey(ownerID int64, keyID, token, signature string) (string, error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return "", err
}
Expand Down
4 changes: 2 additions & 2 deletions models/asymkey/ssh_key.go
Expand Up @@ -100,7 +100,7 @@ func AddPublicKey(ownerID int64, name, content string, authSourceID int64) (*Pub
return nil, err
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -321,7 +321,7 @@ func PublicKeyIsExternallyManaged(id int64) (bool, error) {
// deleteKeysMarkedForDeletion returns true if ssh keys needs update
func deleteKeysMarkedForDeletion(keys []string) (bool, error) {
// Start session
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return false, err
}
Expand Down
2 changes: 1 addition & 1 deletion models/asymkey/ssh_key_deploy.go
Expand Up @@ -126,7 +126,7 @@ func AddDeployKey(repoID int64, name, content string, readOnly bool) (*DeployKey
accessMode = perm.AccessModeWrite
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion models/asymkey/ssh_key_principals.go
Expand Up @@ -26,7 +26,7 @@ import (

// AddPrincipalKey adds new principal to database and authorized_principals file.
func AddPrincipalKey(ownerID int64, content string, authSourceID int64) (*PublicKey, error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion models/asymkey/ssh_key_verify.go
Expand Up @@ -15,7 +15,7 @@ import (

// VerifySSHKey marks a SSH key as verified
func VerifySSHKey(ownerID int64, fingerprint, token, signature string) (string, error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return "", err
}
Expand Down
4 changes: 2 additions & 2 deletions models/auth/oauth2.go
Expand Up @@ -201,7 +201,7 @@ type UpdateOAuth2ApplicationOptions struct {

// UpdateOAuth2Application updates an oauth2 application
func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Application, error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -265,7 +265,7 @@ func deleteOAuth2Application(ctx context.Context, id, userid int64) error {

// DeleteOAuth2Application deletes the application with the given id and the grants and auth codes related to it. It checks if the userid was the creator of the app.
func DeleteOAuth2Application(id, userid int64) error {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions models/auth/session.go
Expand Up @@ -37,7 +37,7 @@ func ReadSession(key string) (*Session, error) {
Key: key,
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -73,7 +73,7 @@ func DestroySession(key string) error {

// RegenerateSession regenerates a session from the old id
func RegenerateSession(oldKey, newKey string) (*Session, error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion models/avatars/avatar.go
Expand Up @@ -97,7 +97,7 @@ func saveEmailHash(email string) string {
Hash: emailHash,
}
// OK we're going to open a session just because I think that that might hide away any problems with postgres reporting errors
if err := db.WithTx(func(ctx context.Context) error {
if err := db.WithTx(db.DefaultContext, func(ctx context.Context) error {
has, err := db.GetEngine(ctx).Where("email = ? AND hash = ?", emailHash.Email, emailHash.Hash).Get(new(EmailHash))
if has || err != nil {
// Seriously we don't care about any DB problems just return the lowerEmail - we expect the transaction to fail most of the time
Expand Down
64 changes: 57 additions & 7 deletions models/db/context.go
Expand Up @@ -7,7 +7,9 @@ package db
import (
"context"
"database/sql"
"errors"

"xorm.io/xorm"
"xorm.io/xorm/schemas"
)

Expand Down Expand Up @@ -86,7 +88,11 @@ type Committer interface {
}

// TxContext represents a transaction Context
func TxContext() (*Context, Committer, error) {
func TxContext(parentCtx context.Context) (*Context, Committer, error) {
if InTransaction(parentCtx) {
return nil, nil, ErrAlreadyInTransaction
}

sess := x.NewSession()
if err := sess.Begin(); err != nil {
sess.Close()
Expand All @@ -96,13 +102,32 @@ func TxContext() (*Context, Committer, error) {
return newContext(DefaultContext, sess, true), sess, nil
}

var ErrAlreadyInTransaction = errors.New("database connection has already been in a transaction")
6543 marked this conversation as resolved.
Show resolved Hide resolved

// WithTx represents executing database operations on a transaction
// you can optionally change the context to a parent one
func WithTx(f func(ctx context.Context) error, stdCtx ...context.Context) error {
parentCtx := DefaultContext
if len(stdCtx) != 0 && stdCtx[0] != nil {
// TODO: make sure parent context has no open session
parentCtx = stdCtx[0]
func WithTx(parentCtx context.Context, f func(ctx context.Context) error) error {
if InTransaction(parentCtx) {
return ErrAlreadyInTransaction
}

sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
}

if err := f(newContext(parentCtx, sess, true)); err != nil {
return err
}

return sess.Commit()
}

// MustTx represents executing database operations on a transaction, if the transaction exist,
// this function will reuse it otherwise will create a new one and close it when finished.
func MustTx(parentCtx context.Context, f func(ctx context.Context) error) error {
lunny marked this conversation as resolved.
Show resolved Hide resolved
if InTransaction(parentCtx) {
return f(newContext(parentCtx, GetEngine(parentCtx), true))
}

sess := x.NewSession()
Expand Down Expand Up @@ -180,3 +205,28 @@ func EstimateCount(ctx context.Context, bean interface{}) (int64, error) {
}
return rows, err
}

// InTransaction returns true if the engine is in a transaction otherwise return false
func InTransaction(ctx context.Context) bool {
var e Engine
if engined, ok := ctx.(Engined); ok {
e = engined.Engine()
} else {
enginedInterface := ctx.Value(enginedContextKey)
if enginedInterface != nil {
e = enginedInterface.(Engined).Engine()
}
}
if e == nil {
return false
}

switch t := e.(type) {
case *xorm.Engine:
return false
case *xorm.Session:
return t.IsInTx()
default:
return false
}
}
33 changes: 33 additions & 0 deletions models/db/context_test.go
@@ -0,0 +1,33 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

package db_test

import (
"context"
"testing"

"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"

"github.com/stretchr/testify/assert"
)

func TestInTransaction(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
assert.False(t, db.InTransaction(db.DefaultContext))
assert.NoError(t, db.WithTx(db.DefaultContext, func(ctx context.Context) error {
assert.True(t, db.InTransaction(ctx))
return nil
}))

ctx, committer, err := db.TxContext(db.DefaultContext)
assert.NoError(t, err)
defer committer.Close()
assert.True(t, db.InTransaction(ctx))
assert.Error(t, db.WithTx(ctx, func(ctx context.Context) error {
assert.True(t, db.InTransaction(ctx))
return nil
}))
}
8 changes: 4 additions & 4 deletions models/db/index_test.go
Expand Up @@ -59,7 +59,7 @@ func TestSyncMaxResourceIndex(t *testing.T) {
assert.EqualValues(t, 62, maxIndex)

// commit transaction
err = db.WithTx(func(ctx context.Context) error {
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 73)
assert.NoError(t, err)
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
Expand All @@ -73,7 +73,7 @@ func TestSyncMaxResourceIndex(t *testing.T) {
assert.EqualValues(t, 73, maxIndex)

// rollback transaction
err = db.WithTx(func(ctx context.Context) error {
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 84)
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
assert.NoError(t, err)
Expand Down Expand Up @@ -102,7 +102,7 @@ func TestGetNextResourceIndex(t *testing.T) {
assert.EqualValues(t, 2, maxIndex)

// commit transaction
err = db.WithTx(func(ctx context.Context) error {
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
assert.NoError(t, err)
assert.EqualValues(t, 3, maxIndex)
Expand All @@ -114,7 +114,7 @@ func TestGetNextResourceIndex(t *testing.T) {
assert.EqualValues(t, 3, maxIndex)

// rollback transaction
err = db.WithTx(func(ctx context.Context) error {
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
assert.NoError(t, err)
assert.EqualValues(t, 4, maxIndex)
Expand Down
2 changes: 1 addition & 1 deletion models/git/branches.go
Expand Up @@ -544,7 +544,7 @@ func FindRenamedBranch(repoID int64, from string) (branch *RenamedBranch, exist

// RenameBranch rename a branch
func RenameBranch(repo *repo_model.Repository, from, to string, gitAction func(isDefault bool) error) (err error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion models/git/branches_test.go
Expand Up @@ -102,7 +102,7 @@ func TestRenameBranch(t *testing.T) {
repo1 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1})
_isDefault := false

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
defer committer.Close()
assert.NoError(t, err)
assert.NoError(t, git_model.UpdateProtectBranch(ctx, repo1, &git_model.ProtectedBranch{
Expand Down
4 changes: 2 additions & 2 deletions models/git/commit_status.go
Expand Up @@ -94,7 +94,7 @@ func GetNextCommitStatusIndex(repoID int64, sha string) (int64, error) {

// getNextCommitStatusIndex return the next index
func getNextCommitStatusIndex(repoID int64, sha string) (int64, error) {
ctx, commiter, err := db.TxContext()
ctx, commiter, err := db.TxContext(db.DefaultContext)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -297,7 +297,7 @@ func NewCommitStatus(opts NewCommitStatusOptions) error {
return fmt.Errorf("generate commit status index failed: %w", err)
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return fmt.Errorf("NewCommitStatus[repo_id: %d, user_id: %d, sha: %s]: %w", opts.Repo.ID, opts.Creator.ID, opts.SHA, err)
}
Expand Down
6 changes: 3 additions & 3 deletions models/git/lfs.go
Expand Up @@ -137,7 +137,7 @@ var ErrLFSObjectNotExist = db.ErrNotExist{Resource: "LFS Meta object"}
func NewLFSMetaObject(m *LFSMetaObject) (*LFSMetaObject, error) {
var err error

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -185,7 +185,7 @@ func RemoveLFSMetaObjectByOid(repoID int64, oid string) (int64, error) {
return 0, ErrLFSObjectNotExist
}

ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -242,7 +242,7 @@ func LFSObjectIsAssociated(oid string) (bool, error) {

// LFSAutoAssociate auto associates accessible LFSMetaObjects
func LFSAutoAssociate(metas []*LFSMetaObject, user *user_model.User, repoID int64) error {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions models/git/lfs_lock.go
Expand Up @@ -44,7 +44,7 @@ func cleanPath(p string) string {

// CreateLFSLock creates a new lock.
func CreateLFSLock(repo *repo_model.Repository, lock *LFSLock) (*LFSLock, error) {
dbCtx, committer, err := db.TxContext()
dbCtx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -137,7 +137,7 @@ func CountLFSLockByRepoID(repoID int64) (int64, error) {

// DeleteLFSLockByID deletes a lock by given ID.
func DeleteLFSLockByID(id int64, repo *repo_model.Repository, u *user_model.User, force bool) (*LFSLock, error) {
dbCtx, committer, err := db.TxContext()
dbCtx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion models/issues/assignees.go
Expand Up @@ -64,7 +64,7 @@ func IsUserAssignedToIssue(ctx context.Context, issue *Issue, user *user_model.U

// ToggleIssueAssignee changes a user between assigned and not assigned for this issue, and make issue comment for it.
func ToggleIssueAssignee(issue *Issue, doer *user_model.User, assigneeID int64) (removed bool, comment *Comment, err error) {
ctx, committer, err := db.TxContext()
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return false, nil, err
}
Expand Down