Skip to content

Commit

Permalink
db: use context and go-mockgen for login sources (#7041)
Browse files Browse the repository at this point in the history
  • Loading branch information
unknwon committed Jun 10, 2022
1 parent 94059f2 commit 9776bdc
Show file tree
Hide file tree
Showing 14 changed files with 2,303 additions and 535 deletions.
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.16
require (
github.com/Masterminds/semver/v3 v3.1.1
github.com/bgentry/speakeasy v0.1.0 // indirect
github.com/derision-test/go-mockgen v1.2.0
github.com/editorconfig/editorconfig-core-go/v2 v2.4.4
github.com/fatih/color v1.9.0 // indirect
github.com/go-ldap/ldap/v3 v3.4.3
Expand Down Expand Up @@ -52,8 +53,7 @@ require (
github.com/unknwon/paginater v0.0.0-20170405233947-45e5d631308e
github.com/urfave/cli v1.22.9
golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
golang.org/x/net v0.0.0-20220325170049-de3da57026de
golang.org/x/text v0.3.7
gopkg.in/DATA-DOG/go-sqlmock.v2 v2.0.0-20180914054222-c19298f520d0
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
Expand Down
55 changes: 50 additions & 5 deletions go.sum

Large diffs are not rendered by default.

13 changes: 6 additions & 7 deletions internal/db/login_source_files_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ import (
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"gogs.io/gogs/internal/errutil"
)

func Test_loginSourceFiles_GetByID(t *testing.T) {
func TestLoginSourceFiles_GetByID(t *testing.T) {
store := &loginSourceFiles{
sources: []*LoginSource{
{ID: 101},
Expand All @@ -28,14 +29,12 @@ func Test_loginSourceFiles_GetByID(t *testing.T) {

t.Run("source exists", func(t *testing.T) {
source, err := store.GetByID(101)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
assert.Equal(t, int64(101), source.ID)
})
}

func Test_loginSourceFiles_Len(t *testing.T) {
func TestLoginSourceFiles_Len(t *testing.T) {
store := &loginSourceFiles{
sources: []*LoginSource{
{ID: 101},
Expand All @@ -45,7 +44,7 @@ func Test_loginSourceFiles_Len(t *testing.T) {
assert.Equal(t, 1, store.Len())
}

func Test_loginSourceFiles_List(t *testing.T) {
func TestLoginSourceFiles_List(t *testing.T) {
store := &loginSourceFiles{
sources: []*LoginSource{
{ID: 101, IsActived: true},
Expand All @@ -65,7 +64,7 @@ func Test_loginSourceFiles_List(t *testing.T) {
})
}

func Test_loginSourceFiles_Update(t *testing.T) {
func TestLoginSourceFiles_Update(t *testing.T) {
store := &loginSourceFiles{
sources: []*LoginSource{
{ID: 101, IsActived: true, IsDefault: true},
Expand Down
75 changes: 40 additions & 35 deletions internal/db/login_sources.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package db

import (
"context"
"fmt"
"strconv"
"time"
Expand All @@ -25,24 +26,24 @@ import (
//
// NOTE: All methods are sorted in alphabetical order.
type LoginSourcesStore interface {
// Create creates a new login source and persist to database.
// It returns ErrLoginSourceAlreadyExist when a login source with same name already exists.
Create(opts CreateLoginSourceOpts) (*LoginSource, error)
// Create creates a new login source and persist to database. It returns
// ErrLoginSourceAlreadyExist when a login source with same name already exists.
Create(ctx context.Context, opts CreateLoginSourceOpts) (*LoginSource, error)
// Count returns the total number of login sources.
Count() int64
// DeleteByID deletes a login source by given ID.
// It returns ErrLoginSourceInUse if at least one user is associated with the login source.
DeleteByID(id int64) error
// GetByID returns the login source with given ID.
// It returns ErrLoginSourceNotExist when not found.
GetByID(id int64) (*LoginSource, error)
Count(ctx context.Context) int64
// DeleteByID deletes a login source by given ID. It returns ErrLoginSourceInUse
// if at least one user is associated with the login source.
DeleteByID(ctx context.Context, id int64) error
// GetByID returns the login source with given ID. It returns
// ErrLoginSourceNotExist when not found.
GetByID(ctx context.Context, id int64) (*LoginSource, error)
// List returns a list of login sources filtered by options.
List(opts ListLoginSourceOpts) ([]*LoginSource, error)
List(ctx context.Context, opts ListLoginSourceOpts) ([]*LoginSource, error)
// ResetNonDefault clears default flag for all the other login sources.
ResetNonDefault(source *LoginSource) error
// Save persists all values of given login source to database or local file.
// The Updated field is set to current time automatically.
Save(t *LoginSource) error
ResetNonDefault(ctx context.Context, source *LoginSource) error
// Save persists all values of given login source to database or local file. The
// Updated field is set to current time automatically.
Save(ctx context.Context, t *LoginSource) error
}

var LoginSources LoginSourcesStore
Expand All @@ -65,7 +66,7 @@ type LoginSource struct {
File loginSourceFileStore `xorm:"-" gorm:"-" json:"-"`
}

// NOTE: This is a GORM save hook.
// BeforeSave implements the GORM save hook.
func (s *LoginSource) BeforeSave(_ *gorm.DB) (err error) {
if s.Provider == nil {
return nil
Expand All @@ -74,7 +75,7 @@ func (s *LoginSource) BeforeSave(_ *gorm.DB) (err error) {
return err
}

// NOTE: This is a GORM create hook.
// BeforeCreate implements the GORM create hook.
func (s *LoginSource) BeforeCreate(tx *gorm.DB) error {
if s.CreatedUnix == 0 {
s.CreatedUnix = tx.NowFunc().Unix()
Expand All @@ -83,13 +84,13 @@ func (s *LoginSource) BeforeCreate(tx *gorm.DB) error {
return nil
}

// NOTE: This is a GORM update hook.
// BeforeUpdate implements the GORM update hook.
func (s *LoginSource) BeforeUpdate(tx *gorm.DB) error {
s.UpdatedUnix = tx.NowFunc().Unix()
return nil
}

// NOTE: This is a GORM query hook.
// AfterFind implements the GORM query hook.
func (s *LoginSource) AfterFind(_ *gorm.DB) error {
s.Created = time.Unix(s.CreatedUnix, 0).Local()
s.Updated = time.Unix(s.UpdatedUnix, 0).Local()
Expand Down Expand Up @@ -209,8 +210,8 @@ func (err ErrLoginSourceAlreadyExist) Error() string {
return fmt.Sprintf("login source already exists: %v", err.args)
}

func (db *loginSources) Create(opts CreateLoginSourceOpts) (*LoginSource, error) {
err := db.Where("name = ?", opts.Name).First(new(LoginSource)).Error
func (db *loginSources) Create(ctx context.Context, opts CreateLoginSourceOpts) (*LoginSource, error) {
err := db.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error
if err == nil {
return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
} else if err != gorm.ErrRecordNotFound {
Expand All @@ -227,12 +228,12 @@ func (db *loginSources) Create(opts CreateLoginSourceOpts) (*LoginSource, error)
if err != nil {
return nil, err
}
return source, db.DB.Create(source).Error
return source, db.WithContext(ctx).Create(source).Error
}

func (db *loginSources) Count() int64 {
func (db *loginSources) Count(ctx context.Context) int64 {
var count int64
db.Model(new(LoginSource)).Count(&count)
db.WithContext(ctx).Model(new(LoginSource)).Count(&count)
return count + int64(db.files.Len())
}

Expand All @@ -249,21 +250,21 @@ func (err ErrLoginSourceInUse) Error() string {
return fmt.Sprintf("login source is still used by some users: %v", err.args)
}

func (db *loginSources) DeleteByID(id int64) error {
func (db *loginSources) DeleteByID(ctx context.Context, id int64) error {
var count int64
err := db.Model(new(User)).Where("login_source = ?", id).Count(&count).Error
err := db.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error
if err != nil {
return err
} else if count > 0 {
return ErrLoginSourceInUse{args: errutil.Args{"id": id}}
}

return db.Where("id = ?", id).Delete(new(LoginSource)).Error
return db.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error
}

func (db *loginSources) GetByID(id int64) (*LoginSource, error) {
func (db *loginSources) GetByID(ctx context.Context, id int64) (*LoginSource, error) {
source := new(LoginSource)
err := db.Where("id = ?", id).First(source).Error
err := db.WithContext(ctx).Where("id = ?", id).First(source).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return db.files.GetByID(id)
Expand All @@ -278,9 +279,9 @@ type ListLoginSourceOpts struct {
OnlyActivated bool
}

func (db *loginSources) List(opts ListLoginSourceOpts) ([]*LoginSource, error) {
func (db *loginSources) List(ctx context.Context, opts ListLoginSourceOpts) ([]*LoginSource, error) {
var sources []*LoginSource
query := db.Order("id ASC")
query := db.WithContext(ctx).Order("id ASC")
if opts.OnlyActivated {
query = query.Where("is_actived = ?", true)
}
Expand All @@ -292,8 +293,12 @@ func (db *loginSources) List(opts ListLoginSourceOpts) ([]*LoginSource, error) {
return append(sources, db.files.List(opts)...), nil
}

func (db *loginSources) ResetNonDefault(dflt *LoginSource) error {
err := db.Model(new(LoginSource)).Where("id != ?", dflt.ID).Updates(map[string]interface{}{"is_default": false}).Error
func (db *loginSources) ResetNonDefault(ctx context.Context, dflt *LoginSource) error {
err := db.WithContext(ctx).
Model(new(LoginSource)).
Where("id != ?", dflt.ID).
Updates(map[string]interface{}{"is_default": false}).
Error
if err != nil {
return err
}
Expand All @@ -311,9 +316,9 @@ func (db *loginSources) ResetNonDefault(dflt *LoginSource) error {
return nil
}

func (db *loginSources) Save(source *LoginSource) error {
func (db *loginSources) Save(ctx context.Context, source *LoginSource) error {
if source.File == nil {
return db.DB.Save(source).Error
return db.WithContext(ctx).Save(source).Error
}

source.File.SetGeneral("name", source.Name)
Expand Down
Loading

0 comments on commit 9776bdc

Please sign in to comment.