Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions api/application_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -611,16 +611,16 @@ func upload(values map[string]*os.File) (contentType string, buffer bytes.Buffer
for key, r := range values {
var fw io.Writer
if fw, err = w.CreateFormFile(key, r.Name()); err != nil {
return
return contentType, buffer, err
}

if _, err = io.Copy(fw, r); err != nil {
return
return contentType, buffer, err
}
}
contentType = w.FormDataContentType()
w.Close()
return
return contentType, buffer, err
}

func mustOpen(f string) *os.File {
Expand Down
2 changes: 1 addition & 1 deletion api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type UserDatabase interface {
DeleteUserByID(id uint) error
UpdateUser(user *model.User) error
CreateUser(user *model.User) error
CountUser(condition ...interface{}) (int, error)
CountUser(condition ...interface{}) (int64, error)
}

// UserChangeNotifier notifies listeners for user changes.
Expand Down
4 changes: 2 additions & 2 deletions auth/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (s *AuthenticationSuite) assertQueryRequest(key, value string, f fMiddlewar
ctx.Request = httptest.NewRequest("GET", fmt.Sprintf("/?%s=%s", key, value), nil)
f()(ctx)
assert.Equal(s.T(), code, recorder.Code)
return
return ctx
}

func (s *AuthenticationSuite) TestNothingProvided() {
Expand Down Expand Up @@ -217,7 +217,7 @@ func (s *AuthenticationSuite) assertHeaderRequest(key, value string, f fMiddlewa
ctx.Request.Header.Set(key, value)
f()(ctx)
assert.Equal(s.T(), code, recorder.Code)
return
return ctx
}

type fMiddleware func() gin.HandlerFunc
2 changes: 1 addition & 1 deletion database/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"time"

"github.com/gotify/server/v2/model"
"github.com/jinzhu/gorm"
"gorm.io/gorm"
)

// GetApplicationByToken returns the application for the given token or nil.
Expand Down
2 changes: 1 addition & 1 deletion database/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"time"

"github.com/gotify/server/v2/model"
"github.com/jinzhu/gorm"
"gorm.io/gorm"
)

// GetClientByID returns the client for the given id or nil.
Expand Down
92 changes: 52 additions & 40 deletions database/database.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
package database

import (
"errors"
"log"
"os"
"path/filepath"
"time"

"github.com/gotify/server/v2/auth/password"
"github.com/gotify/server/v2/mode"
"github.com/gotify/server/v2/model"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mysql" // enable the mysql dialect.
_ "github.com/jinzhu/gorm/dialects/postgres" // enable the postgres dialect.
_ "github.com/jinzhu/gorm/dialects/sqlite" // enable the sqlite3 dialect.
"github.com/mattn/go-isatty"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)

var mkdirAll = os.MkdirAll
Expand All @@ -19,40 +24,68 @@ var mkdirAll = os.MkdirAll
func New(dialect, connection, defaultUser, defaultPass string, strength int, createDefaultUserIfNotExist bool) (*GormDatabase, error) {
createDirectoryIfSqlite(dialect, connection)

db, err := gorm.Open(dialect, connection)
logLevel := logger.Info
if mode.Get() == mode.Prod {
logLevel = logger.Warn
}

dbLogger := logger.New(log.New(os.Stderr, "\r\n", log.LstdFlags), logger.Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: logLevel,
IgnoreRecordNotFoundError: true,
Colorful: isatty.IsTerminal(os.Stderr.Fd()),
})
gormConfig := &gorm.Config{
Logger: dbLogger,
DisableForeignKeyConstraintWhenMigrating: true,
}

var db *gorm.DB
err := errors.New("unsupported dialect: " + dialect)

switch dialect {
case "mysql":
db, err = gorm.Open(mysql.Open(connection), gormConfig)
case "postgres":
db, err = gorm.Open(postgres.Open(connection), gormConfig)
case "sqlite3":
db, err = gorm.Open(sqlite.Open(connection), gormConfig)
}

if err != nil {
return nil, err
}

sqldb, err := db.DB()
if err != nil {
return nil, err
}

// We normally don't need that much connections, so we limit them. F.ex. mysql complains about
// "too many connections", while load testing Gotify.
db.DB().SetMaxOpenConns(10)
sqldb.SetMaxOpenConns(10)

if dialect == "sqlite3" {
// We use the database connection inside the handlers from the http
// framework, therefore concurrent access occurs. Sqlite cannot handle
// concurrent writes, so we limit sqlite to one connection.
// see https://github.com/mattn/go-sqlite3/issues/274
db.DB().SetMaxOpenConns(1)
sqldb.SetMaxOpenConns(1)
}

if dialect == "mysql" {
// Mysql has a setting called wait_timeout, which defines the duration
// after which a connection may not be used anymore.
// The default for this setting on mariadb is 10 minutes.
// See https://github.com/docker-library/mariadb/issues/113
db.DB().SetConnMaxLifetime(9 * time.Minute)
sqldb.SetConnMaxLifetime(9 * time.Minute)
}

if err := db.AutoMigrate(new(model.User), new(model.Application), new(model.Message), new(model.Client), new(model.PluginConf)).Error; err != nil {
if err := db.AutoMigrate(new(model.User), new(model.Application), new(model.Message), new(model.Client), new(model.PluginConf)); err != nil {
return nil, err
}

if err := prepareBlobColumn(dialect, db); err != nil {
return nil, err
}

userCount := 0
userCount := int64(0)
db.Find(new(model.User)).Count(&userCount)
if createDefaultUserIfNotExist && userCount == 0 {
db.Create(&model.User{Name: defaultUser, Pass: password.CreatePassword(defaultPass, strength), Admin: true})
Expand All @@ -61,31 +94,6 @@ func New(dialect, connection, defaultUser, defaultPass string, strength int, cre
return &GormDatabase{DB: db}, nil
}

func prepareBlobColumn(dialect string, db *gorm.DB) error {
blobType := ""
switch dialect {
case "mysql":
blobType = "longblob"
case "postgres":
blobType = "bytea"
}
if blobType != "" {
for _, target := range []struct {
Table interface{}
Column string
}{
{model.Message{}, "extras"},
{model.PluginConf{}, "config"},
{model.PluginConf{}, "storage"},
} {
if err := db.Model(target.Table).ModifyColumn(target.Column, blobType).Error; err != nil {
return err
}
}
}
return nil
}

func createDirectoryIfSqlite(dialect, connection string) {
if dialect == "sqlite3" {
if _, err := os.Stat(filepath.Dir(connection)); os.IsNotExist(err) {
Expand All @@ -103,5 +111,9 @@ type GormDatabase struct {

// Close closes the gorm database connection.
func (d *GormDatabase) Close() {
d.DB.Close()
sqldb, err := d.DB.DB()
if err != nil {
return
}
sqldb.Close()
}
10 changes: 5 additions & 5 deletions database/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package database

import (
"github.com/gotify/server/v2/model"
"github.com/jinzhu/gorm"
"gorm.io/gorm"
)

// GetMessageByID returns the messages for the given id or nil.
Expand All @@ -27,7 +27,7 @@ func (d *GormDatabase) CreateMessage(message *model.Message) error {
func (d *GormDatabase) GetMessagesByUser(userID uint) ([]*model.Message, error) {
var messages []*model.Message
err := d.DB.Joins("JOIN applications ON applications.user_id = ?", userID).
Where("messages.application_id = applications.id").Order("id desc").Find(&messages).Error
Where("messages.application_id = applications.id").Order("messages.id desc").Find(&messages).Error
if err == gorm.ErrRecordNotFound {
err = nil
}
Expand All @@ -39,7 +39,7 @@ func (d *GormDatabase) GetMessagesByUser(userID uint) ([]*model.Message, error)
func (d *GormDatabase) GetMessagesByUserSince(userID uint, limit int, since uint) ([]*model.Message, error) {
var messages []*model.Message
db := d.DB.Joins("JOIN applications ON applications.user_id = ?", userID).
Where("messages.application_id = applications.id").Order("id desc").Limit(limit)
Where("messages.application_id = applications.id").Order("messages.id desc").Limit(limit)
if since != 0 {
db = db.Where("messages.id < ?", since)
}
Expand All @@ -53,7 +53,7 @@ func (d *GormDatabase) GetMessagesByUserSince(userID uint, limit int, since uint
// GetMessagesByApplication returns all messages from an application.
func (d *GormDatabase) GetMessagesByApplication(tokenID uint) ([]*model.Message, error) {
var messages []*model.Message
err := d.DB.Where("application_id = ?", tokenID).Order("id desc").Find(&messages).Error
err := d.DB.Where("application_id = ?", tokenID).Order("messages.id desc").Find(&messages).Error
if err == gorm.ErrRecordNotFound {
err = nil
}
Expand All @@ -64,7 +64,7 @@ func (d *GormDatabase) GetMessagesByApplication(tokenID uint) ([]*model.Message,
// If since is 0 it will be ignored.
func (d *GormDatabase) GetMessagesByApplicationSince(appID uint, limit int, since uint) ([]*model.Message, error) {
var messages []*model.Message
db := d.DB.Where("application_id = ?", appID).Order("id desc").Limit(limit)
db := d.DB.Where("application_id = ?", appID).Order("messages.id desc").Limit(limit)
if since != 0 {
db = db.Where("messages.id < ?", since)
}
Expand Down
17 changes: 10 additions & 7 deletions database/migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ import (

"github.com/gotify/server/v2/model"
"github.com/gotify/server/v2/test"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)

func TestMigration(t *testing.T) {
Expand All @@ -21,18 +22,20 @@ type MigrationSuite struct {

func (s *MigrationSuite) BeforeTest(suiteName, testName string) {
s.tmpDir = test.NewTmpDir("gotify_migrationsuite")
db, err := gorm.Open("sqlite3", s.tmpDir.Path("test_obsolete.db"))
assert.Nil(s.T(), err)
defer db.Close()
db, err := gorm.Open(sqlite.Open(s.tmpDir.Path("test_obsolete.db")), &gorm.Config{})
assert.NoError(s.T(), err)
sqlDB, err := db.DB()
assert.NoError(s.T(), err)
defer sqlDB.Close()

assert.Nil(s.T(), db.CreateTable(new(model.User)).Error)
assert.Nil(s.T(), db.Migrator().CreateTable(new(model.User)))
assert.Nil(s.T(), db.Create(&model.User{
Name: "test_user",
Admin: true,
}).Error)

// we should not be able to create applications by now
assert.False(s.T(), db.HasTable(new(model.Application)))
assert.False(s.T(), db.Migrator().HasTable(new(model.Application)))
}

func (s *MigrationSuite) AfterTest(suiteName, testName string) {
Expand All @@ -44,7 +47,7 @@ func (s *MigrationSuite) TestMigration() {
assert.Nil(s.T(), err)
defer db.Close()

assert.True(s.T(), db.DB.HasTable(new(model.Application)))
assert.True(s.T(), db.DB.Migrator().HasTable(new(model.Application)))

// a user already exist, not adding a new user
if user, err := db.GetUserByName("admin"); assert.NoError(s.T(), err) {
Expand Down
6 changes: 5 additions & 1 deletion database/ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,9 @@ package database

// Ping pings the database to verify the connection.
func (d *GormDatabase) Ping() error {
return d.DB.DB().Ping()
sqldb, err := d.DB.DB()
if err != nil {
return err
}
return sqldb.Ping()
}
2 changes: 1 addition & 1 deletion database/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package database

import (
"github.com/gotify/server/v2/model"
"github.com/jinzhu/gorm"
"gorm.io/gorm"
)

// GetPluginConfByUser gets plugin configurations from a user.
Expand Down
6 changes: 3 additions & 3 deletions database/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package database

import (
"github.com/gotify/server/v2/model"
"github.com/jinzhu/gorm"
"gorm.io/gorm"
)

// GetUserByName returns the user by the given name or nil.
Expand Down Expand Up @@ -32,8 +32,8 @@ func (d *GormDatabase) GetUserByID(id uint) (*model.User, error) {
}

// CountUser returns the user count which satisfies the given condition.
func (d *GormDatabase) CountUser(condition ...interface{}) (int, error) {
c := -1
func (d *GormDatabase) CountUser(condition ...interface{}) (int64, error) {
c := int64(-1)
handle := d.DB.Model(new(model.User))
if len(condition) == 1 {
handle = handle.Where(condition[0])
Expand Down
6 changes: 3 additions & 3 deletions database/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (s *DatabaseSuite) TestUser() {

adminCount, err := s.db.CountUser("admin = ?", true)
require.NoError(s.T(), err)
assert.Equal(s.T(), 1, adminCount, 1, "there is initially one admin")
assert.Equal(s.T(), int64(1), adminCount, "there is initially one admin")

users, err := s.db.GetUsers()
require.NoError(s.T(), err)
Expand All @@ -33,7 +33,7 @@ func (s *DatabaseSuite) TestUser() {
assert.NotEqual(s.T(), 0, nicories.ID, "on create user a new id should be assigned")
userCount, err := s.db.CountUser()
require.NoError(s.T(), err)
assert.Equal(s.T(), 2, userCount, "two users should exist")
assert.Equal(s.T(), int64(2), userCount, "two users should exist")

user, err = s.db.GetUserByName("nicories")
require.NoError(s.T(), err)
Expand All @@ -60,7 +60,7 @@ func (s *DatabaseSuite) TestUser() {

adminCount, err = s.db.CountUser(&model.User{Admin: true})
require.NoError(s.T(), err)
assert.Equal(s.T(), 2, adminCount, "two admins exist")
assert.Equal(s.T(), int64(2), adminCount, "two admins exist")

require.NoError(s.T(), s.db.DeleteUserByID(tom.ID))
users, err = s.db.GetUsers()
Expand Down
Loading
Loading