Skip to content

Commit

Permalink
Even more db.DefaultContext refactor (#27352)
Browse files Browse the repository at this point in the history
Part of #27065

---------

Co-authored-by: Lunny Xiao <xiaolunwen@gmail.com>
Co-authored-by: delvh <dev.lh@web.de>
  • Loading branch information
3 people committed Oct 3, 2023
1 parent 08507e2 commit cc5df26
Show file tree
Hide file tree
Showing 97 changed files with 298 additions and 294 deletions.
2 changes: 1 addition & 1 deletion models/activities/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ func activityQueryCondition(ctx context.Context, opts GetFeedsOptions) (builder.
}

if opts.RequestedTeam != nil {
env := organization.OrgFromUser(opts.RequestedUser).AccessibleTeamReposEnv(opts.RequestedTeam)
env := organization.OrgFromUser(opts.RequestedUser).AccessibleTeamReposEnv(ctx, opts.RequestedTeam)
teamRepoIDs, err := env.RepoIDs(1, opts.RequestedUser.NumRepos)
if err != nil {
return nil, fmt.Errorf("GetTeamRepositories: %w", err)
Expand Down
2 changes: 1 addition & 1 deletion models/activities/statistic.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ type IssueByRepositoryCount struct {
func GetStatistic(ctx context.Context) (stats Statistic) {
e := db.GetEngine(ctx)
stats.Counter.User = user_model.CountUsers(ctx, nil)
stats.Counter.Org, _ = organization.CountOrgs(organization.FindOrgOptions{IncludePrivate: true})
stats.Counter.Org, _ = organization.CountOrgs(ctx, organization.FindOrgOptions{IncludePrivate: true})
stats.Counter.PublicKey, _ = e.Count(new(asymkey_model.PublicKey))
stats.Counter.Repo, _ = repo_model.CountRepositories(ctx, repo_model.CountRepositoryOptions{})
stats.Counter.Watch, _ = e.Count(new(repo_model.Watch))
Expand Down
2 changes: 1 addition & 1 deletion models/issues/assignees_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func TestUpdateAssignee(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())

// Fake issue with assignees
issue, err := issues_model.GetIssueWithAttrsByID(1)
issue, err := issues_model.GetIssueWithAttrsByID(db.DefaultContext, 1)
assert.NoError(t, err)

// Assign multiple users
Expand Down
4 changes: 2 additions & 2 deletions models/issues/comment.go
Original file line number Diff line number Diff line change
Expand Up @@ -655,12 +655,12 @@ func (c *Comment) LoadDepIssueDetails(ctx context.Context) (err error) {
}

// LoadTime loads the associated time for a CommentTypeAddTimeManual
func (c *Comment) LoadTime() error {
func (c *Comment) LoadTime(ctx context.Context) error {
if c.Time != nil || c.TimeID == 0 {
return nil
}
var err error
c.Time, err = GetTrackedTimeByID(c.TimeID)
c.Time, err = GetTrackedTimeByID(ctx, c.TimeID)
return err
}

Expand Down
30 changes: 15 additions & 15 deletions models/issues/issue.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,12 @@ func (issue *Issue) IsTimetrackerEnabled(ctx context.Context) bool {
}

// GetPullRequest returns the issue pull request
func (issue *Issue) GetPullRequest() (pr *PullRequest, err error) {
func (issue *Issue) GetPullRequest(ctx context.Context) (pr *PullRequest, err error) {
if !issue.IsPull {
return nil, fmt.Errorf("Issue is not a pull request")
}

pr, err = GetPullRequestByIssueID(db.DefaultContext, issue.ID)
pr, err = GetPullRequestByIssueID(ctx, issue.ID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -369,9 +369,9 @@ func (issue *Issue) LoadAttributes(ctx context.Context) (err error) {
}

// GetIsRead load the `IsRead` field of the issue
func (issue *Issue) GetIsRead(userID int64) error {
func (issue *Issue) GetIsRead(ctx context.Context, userID int64) error {
issueUser := &IssueUser{IssueID: issue.ID, UID: userID}
if has, err := db.GetEngine(db.DefaultContext).Get(issueUser); err != nil {
if has, err := db.GetEngine(ctx).Get(issueUser); err != nil {
return err
} else if !has {
issue.IsRead = false
Expand All @@ -382,9 +382,9 @@ func (issue *Issue) GetIsRead(userID int64) error {
}

// APIURL returns the absolute APIURL to this issue.
func (issue *Issue) APIURL() string {
func (issue *Issue) APIURL(ctx context.Context) string {
if issue.Repo == nil {
err := issue.LoadRepo(db.DefaultContext)
err := issue.LoadRepo(ctx)
if err != nil {
log.Error("Issue[%d].APIURL(): %v", issue.ID, err)
return ""
Expand Down Expand Up @@ -479,9 +479,9 @@ func (issue *Issue) GetLastEventLabel() string {
}

// GetLastComment return last comment for the current issue.
func (issue *Issue) GetLastComment() (*Comment, error) {
func (issue *Issue) GetLastComment(ctx context.Context) (*Comment, error) {
var c Comment
exist, err := db.GetEngine(db.DefaultContext).Where("type = ?", CommentTypeComment).
exist, err := db.GetEngine(ctx).Where("type = ?", CommentTypeComment).
And("issue_id = ?", issue.ID).Desc("created_unix").Get(&c)
if err != nil {
return nil, err
Expand Down Expand Up @@ -543,12 +543,12 @@ func GetIssueByID(ctx context.Context, id int64) (*Issue, error) {
}

// GetIssueWithAttrsByID returns an issue with attributes by given ID.
func GetIssueWithAttrsByID(id int64) (*Issue, error) {
issue, err := GetIssueByID(db.DefaultContext, id)
func GetIssueWithAttrsByID(ctx context.Context, id int64) (*Issue, error) {
issue, err := GetIssueByID(ctx, id)
if err != nil {
return nil, err
}
return issue, issue.LoadAttributes(db.DefaultContext)
return issue, issue.LoadAttributes(ctx)
}

// GetIssuesByIDs return issues with the given IDs.
Expand Down Expand Up @@ -600,8 +600,8 @@ func GetParticipantsIDsByIssueID(ctx context.Context, issueID int64) ([]int64, e
}

// IsUserParticipantsOfIssue return true if user is participants of an issue
func IsUserParticipantsOfIssue(user *user_model.User, issue *Issue) bool {
userIDs, err := issue.GetParticipantIDsByIssue(db.DefaultContext)
func IsUserParticipantsOfIssue(ctx context.Context, user *user_model.User, issue *Issue) bool {
userIDs, err := issue.GetParticipantIDsByIssue(ctx)
if err != nil {
log.Error(err.Error())
return false
Expand Down Expand Up @@ -894,8 +894,8 @@ func IsErrIssueMaxPinReached(err error) bool {
}

// InsertIssues insert issues to database
func InsertIssues(issues ...*Issue) error {
ctx, committer, err := db.TxContext(db.DefaultContext)
func InsertIssues(ctx context.Context, issues ...*Issue) error {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions models/issues/issue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestIssueAPIURL(t *testing.T) {
err := issue.LoadAttributes(db.DefaultContext)

assert.NoError(t, err)
assert.Equal(t, "https://try.gitea.io/api/v1/repos/user2/repo1/issues/1", issue.APIURL())
assert.Equal(t, "https://try.gitea.io/api/v1/repos/user2/repo1/issues/1", issue.APIURL(db.DefaultContext))
}

func TestGetIssuesByIDs(t *testing.T) {
Expand Down Expand Up @@ -477,7 +477,7 @@ func assertCreateIssues(t *testing.T, isPull bool) {
Labels: []*issues_model.Label{label},
Reactions: []*issues_model.Reaction{reaction},
}
err := issues_model.InsertIssues(is)
err := issues_model.InsertIssues(db.DefaultContext, is)
assert.NoError(t, err)

i := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{Title: title})
Expand Down
2 changes: 1 addition & 1 deletion models/issues/issue_watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func CheckIssueWatch(ctx context.Context, user *user_model.User, issue *Issue) (
if err != nil {
return false, err
}
return repo_model.IsWatchMode(w.Mode) || IsUserParticipantsOfIssue(user, issue), nil
return repo_model.IsWatchMode(w.Mode) || IsUserParticipantsOfIssue(ctx, user, issue), nil
}

// GetIssueWatchersIDs returns IDs of subscribers or explicit unsubscribers to a given issue id
Expand Down
2 changes: 1 addition & 1 deletion models/issues/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,7 @@ func ParseCodeOwnersLine(ctx context.Context, tokens []string) (*CodeOwnerRule,
warnings = append(warnings, fmt.Sprintf("incorrect codeowner organization: %s", user))
continue
}
teams, err := org.LoadTeams()
teams, err := org.LoadTeams(ctx)
if err != nil {
warnings = append(warnings, fmt.Sprintf("incorrect codeowner team: %s", user))
continue
Expand Down
18 changes: 9 additions & 9 deletions models/issues/tracked_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,8 @@ func addTime(ctx context.Context, user *user_model.User, issue *Issue, amount in
}

// TotalTimesForEachUser returns the spent time in seconds for each user by an issue
func TotalTimesForEachUser(options *FindTrackedTimesOptions) (map[*user_model.User]int64, error) {
trackedTimes, err := GetTrackedTimes(db.DefaultContext, options)
func TotalTimesForEachUser(ctx context.Context, options *FindTrackedTimesOptions) (map[*user_model.User]int64, error) {
trackedTimes, err := GetTrackedTimes(ctx, options)
if err != nil {
return nil, err
}
Expand All @@ -213,7 +213,7 @@ func TotalTimesForEachUser(options *FindTrackedTimesOptions) (map[*user_model.Us
totalTimes := make(map[*user_model.User]int64)
// Fetching User and making time human readable
for userID, total := range totalTimesByUser {
user, err := user_model.GetUserByID(db.DefaultContext, userID)
user, err := user_model.GetUserByID(ctx, userID)
if err != nil {
if user_model.IsErrUserNotExist(err) {
continue
Expand All @@ -226,8 +226,8 @@ func TotalTimesForEachUser(options *FindTrackedTimesOptions) (map[*user_model.Us
}

// DeleteIssueUserTimes deletes times for issue
func DeleteIssueUserTimes(issue *Issue, user *user_model.User) error {
ctx, committer, err := db.TxContext(db.DefaultContext)
func DeleteIssueUserTimes(ctx context.Context, issue *Issue, user *user_model.User) error {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -265,8 +265,8 @@ func DeleteIssueUserTimes(issue *Issue, user *user_model.User) error {
}

// DeleteTime delete a specific Time
func DeleteTime(t *TrackedTime) error {
ctx, committer, err := db.TxContext(db.DefaultContext)
func DeleteTime(ctx context.Context, t *TrackedTime) error {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -315,9 +315,9 @@ func deleteTime(ctx context.Context, t *TrackedTime) error {
}

// GetTrackedTimeByID returns raw TrackedTime without loading attributes by id
func GetTrackedTimeByID(id int64) (*TrackedTime, error) {
func GetTrackedTimeByID(ctx context.Context, id int64) (*TrackedTime, error) {
time := new(TrackedTime)
has, err := db.GetEngine(db.DefaultContext).ID(id).Get(time)
has, err := db.GetEngine(ctx).ID(id).Get(time)
if err != nil {
return nil, err
} else if !has {
Expand Down
8 changes: 4 additions & 4 deletions models/issues/tracked_time_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ func TestGetTrackedTimes(t *testing.T) {
func TestTotalTimesForEachUser(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())

total, err := issues_model.TotalTimesForEachUser(&issues_model.FindTrackedTimesOptions{IssueID: 1})
total, err := issues_model.TotalTimesForEachUser(db.DefaultContext, &issues_model.FindTrackedTimesOptions{IssueID: 1})
assert.NoError(t, err)
assert.Len(t, total, 1)
for user, time := range total {
assert.EqualValues(t, 1, user.ID)
assert.EqualValues(t, 400, time)
}

total, err = issues_model.TotalTimesForEachUser(&issues_model.FindTrackedTimesOptions{IssueID: 2})
total, err = issues_model.TotalTimesForEachUser(db.DefaultContext, &issues_model.FindTrackedTimesOptions{IssueID: 2})
assert.NoError(t, err)
assert.Len(t, total, 2)
for user, time := range total {
Expand All @@ -103,15 +103,15 @@ func TestTotalTimesForEachUser(t *testing.T) {
}
}

total, err = issues_model.TotalTimesForEachUser(&issues_model.FindTrackedTimesOptions{IssueID: 5})
total, err = issues_model.TotalTimesForEachUser(db.DefaultContext, &issues_model.FindTrackedTimesOptions{IssueID: 5})
assert.NoError(t, err)
assert.Len(t, total, 1)
for user, time := range total {
assert.EqualValues(t, 2, user.ID)
assert.EqualValues(t, 1, time)
}

total, err = issues_model.TotalTimesForEachUser(&issues_model.FindTrackedTimesOptions{IssueID: 4})
total, err = issues_model.TotalTimesForEachUser(db.DefaultContext, &issues_model.FindTrackedTimesOptions{IssueID: 4})
assert.NoError(t, err)
assert.Len(t, total, 2)
}
2 changes: 1 addition & 1 deletion models/org_team.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ func AddTeamMember(ctx context.Context, team *organization.Team, userID int64) e
return err
}

if err := organization.AddOrgUser(team.OrgID, userID); err != nil {
if err := organization.AddOrgUser(ctx, team.OrgID, userID); err != nil {
return err
}

Expand Down

0 comments on commit cc5df26

Please sign in to comment.