Skip to content

Commit

Permalink
implement sqlserver storage using gorm.io/driver/sqlserver
Browse files Browse the repository at this point in the history
  • Loading branch information
wooln committed Oct 24, 2023
1 parent 07b3d70 commit b210ad7
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 11 deletions.
2 changes: 2 additions & 0 deletions client/dtmcli/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ const (
DBTypeMysql = dtmimp.DBTypeMysql
// DBTypePostgres const for driver postgres
DBTypePostgres = dtmimp.DBTypePostgres
// DBTypeSqlServer const for driver SqlServer
DBTypeSqlServer = dtmimp.DBTypeSqlServer

Check failure on line 39 in client/dtmcli/consts.go

View workflow job for this annotation

GitHub Actions / CI

const DBTypeSqlServer should be DBTypeSQLServer
)

// MapSuccess HTTP result of SUCCESS
Expand Down
2 changes: 2 additions & 0 deletions client/dtmcli/dtmimp/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ const (
DBTypeMysql = "mysql"
// DBTypePostgres const for driver postgres
DBTypePostgres = "postgres"
// DBTypeSqlServer const for driver SqlServer
DBTypeSqlServer = "sqlserver"

Check failure on line 40 in client/dtmcli/dtmimp/consts.go

View workflow job for this annotation

GitHub Actions / CI

const DBTypeSqlServer should be DBTypeSQLServer
// DBTypeRedis const for driver redis
DBTypeRedis = "redis"
// Jrpc const for json-rpc
Expand Down
21 changes: 21 additions & 0 deletions client/dtmcli/dtmimp/db_special.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,27 @@ func init() {
dbSpecials[DBTypePostgres] = &postgresDBSpecial{}
}

// TODO sqlserver implement (for go client only, not for dtm server)
type sqlserverDBSpecial struct{}

func (*sqlserverDBSpecial) GetPlaceHoldSQL(sql string) string {
// TODO sqlserver implement
return sql
}

func (*sqlserverDBSpecial) GetInsertIgnoreTemplate(tableAndValues string, pgConstraint string) string {
// TODO sqlserver implement
return ""
}

func (*sqlserverDBSpecial) GetXaSQL(command string, xid string) string {
// TODO sqlserver implement
return ""
}
func init() {
dbSpecials[DBTypeSqlServer] = &sqlserverDBSpecial{}
}

// GetDBSpecial get DBSpecial for currentDBType
func GetDBSpecial(dbType string) DBSpecial {
if dbType == "" {
Expand Down
15 changes: 15 additions & 0 deletions client/dtmcli/dtmimp/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,26 @@ func GetDsn(conf DBConf) string {
conf.User, conf.Password, host, conf.Port, conf.Db),
"postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' search_path=%s port=%d sslmode=disable",
host, conf.User, conf.Password, conf.Db, conf.Schema, conf.Port),
// sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30
"sqlserver": getSqlServerConnectionString(&conf, &host),
}[driver]
PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver))
return dsn
}

func getSqlServerConnectionString(conf *DBConf, host *string) string {

Check failure on line 238 in client/dtmcli/dtmimp/utils.go

View workflow job for this annotation

GitHub Actions / CI

func getSqlServerConnectionString should be getSQLServerConnectionString
query := url.Values{}
query.Add("database", conf.Db)
u := &url.URL{
Scheme: "sqlserver",
User: url.UserPassword(conf.User, conf.Password),
Host: fmt.Sprintf("%s:%d", *host, conf.Port),
// Path: instance, // if connecting to an instance instead of a port
RawQuery: query.Encode(),
}
return u.String()
}

// RespAsErrorByJSONRPC translate json rpc resty response to error
func RespAsErrorByJSONRPC(resp *resty.Response) error {
str := resp.String()
Expand Down
4 changes: 3 additions & 1 deletion dtmsvr/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ const (
BoltDb = "boltdb"
// Postgres is postgres driver
Postgres = "postgres"
// SqlServer is SQL Server driver
SqlServer = "sqlserver"

Check failure on line 24 in dtmsvr/config/config.go

View workflow job for this annotation

GitHub Actions / CI

const SqlServer should be SQLServer
)

// MicroService config type for microservice based grpc
Expand Down Expand Up @@ -65,7 +67,7 @@ type Store struct {

// IsDB checks config driver is mysql or postgres
func (s *Store) IsDB() bool {
return s.Driver == dtmcli.DBTypeMysql || s.Driver == dtmcli.DBTypePostgres
return s.Driver == dtmcli.DBTypeMysql || s.Driver == dtmcli.DBTypePostgres || s.Driver == dtmcli.DBTypeSqlServer
}

// GetDBConf returns db conf info
Expand Down
5 changes: 3 additions & 2 deletions dtmsvr/storage/registry/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ var storeFactorys = map[string]StorageFactory{
return &redis.Store{}
},
},
"mysql": sqlFac,
"postgres": sqlFac,
"mysql": sqlFac,
"postgres": sqlFac,
"sqlserver": sqlFac,
}

// GetStore returns storage.Store
Expand Down
31 changes: 23 additions & 8 deletions dtmsvr/storage/sql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ func (s *Store) ScanTransGlobalStores(position *string, limit int64, condition s
query = query.Where("trans_type = ?", condition.TransType)
}
if !condition.CreateTimeStart.IsZero() {
query = query.Where("create_time >= ?", condition.CreateTimeStart.Format("2006-01-02 15:04:05"))
query = query.Where("create_time >= ?", condition.CreateTimeStart)
}
if !condition.CreateTimeEnd.IsZero() {
query = query.Where("create_time <= ?", condition.CreateTimeEnd.Format("2006-01-02 15:04:05"))
query = query.Where("create_time <= ?", condition.CreateTimeEnd)
}

dbr := query.Order("id desc").Limit(int(limit)).Find(&globals)
Expand Down Expand Up @@ -105,7 +105,13 @@ func (s *Store) UpdateBranches(branches []storage.TransBranchStore, updates []st
func (s *Store) LockGlobalSaveBranches(gid string, status string, branches []storage.TransBranchStore, branchStart int) {
err := dbGet().Transaction(func(tx *gorm.DB) error {
g := &storage.TransGlobalStore{}
dbr := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Model(g).Where("gid=? and status=?", gid, status).First(g)
var dbr *gorm.DB
// sqlserver sql should be: SELECT * FROM "trans_global" with(RowLock,UpdLock) ,but gorm generates "FOR UPDATE" at the back, raw sql instead.
if conf.Store.Driver == config.SqlServer {
dbr = tx.Raw("SELECT * FROM trans_global with(RowLock,UpdLock) WHERE gid=? and status=? ORDER BY id OFFSET 0 ROW FETCH NEXT 1 ROWS ONLY ", gid, status).First(g)
} else {
dbr = tx.Clauses(clause.Locking{Strength: "UPDATE"}).Model(g).Where("gid=? and status=?", gid, status).First(g)
}
if dbr.Error == nil {
if branchStart == -1 {
dbr = tx.Create(branches)
Expand Down Expand Up @@ -164,11 +170,16 @@ func (s *Store) LockOneGlobalTrans(expireIn time.Duration) *storage.TransGlobalS
where := fmt.Sprintf(`next_cron_time < '%s' and status in ('prepared', 'aborting', 'submitted')`, nextCronTime)

order := map[string]string{
dtmimp.DBTypeMysql: `order by rand()`,
dtmimp.DBTypePostgres: `order by random()`,
dtmimp.DBTypeMysql: `order by rand()`,
dtmimp.DBTypePostgres: `order by random()`,
dtmimp.DBTypeSqlServer: `order by rand()`,
}[conf.Store.Driver]

ssql := fmt.Sprintf(`select id from trans_global where %s %s limit 1`, where, order)
ssql := map[string]string{
dtmimp.DBTypeMysql: fmt.Sprintf(`select id from trans_global where %s %s limit 1`, where, order),
dtmimp.DBTypePostgres: fmt.Sprintf(`select id from trans_global where %s %s limit 1`, where, order),
dtmimp.DBTypeSqlServer: fmt.Sprintf(`select top 1 id from trans_global where %s %s`, where, order),
}[conf.Store.Driver]
var id int64
err := db.ToSQLDB().QueryRow(ssql).Scan(&id)
if errors.Is(err, sql.ErrNoRows) {
Expand Down Expand Up @@ -198,8 +209,9 @@ func (s *Store) LockOneGlobalTrans(expireIn time.Duration) *storage.TransGlobalS
func (s *Store) ResetCronTime(after time.Duration, limit int64) (succeedCount int64, hasRemaining bool, err error) {
nextCronTime := getTimeStr(int64(after / time.Second))
where := map[string]string{
dtmimp.DBTypeMysql: fmt.Sprintf(`next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') limit %d`, nextCronTime, limit),
dtmimp.DBTypePostgres: fmt.Sprintf(`id in (select id from trans_global where next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') limit %d )`, nextCronTime, limit),
dtmimp.DBTypeMysql: fmt.Sprintf(`next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') limit %d`, nextCronTime, limit),
dtmimp.DBTypePostgres: fmt.Sprintf(`id in (select id from trans_global where next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') limit %d )`, nextCronTime, limit),
dtmimp.DBTypeSqlServer: fmt.Sprintf(`id in (select top %d id from trans_global where next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') )`, limit, nextCronTime),
}[conf.Store.Driver]

sql := fmt.Sprintf(`UPDATE trans_global SET update_time='%s',next_cron_time='%s' WHERE %s`,
Expand Down Expand Up @@ -317,5 +329,8 @@ func wrapError(err error) error {
}

func getTimeStr(afterSecond int64) string {
if conf.Store.Driver == config.SqlServer {
return dtmutil.GetNextTime(afterSecond).Format(time.RFC3339)
}
return dtmutil.GetNextTime(afterSecond).Format("2006-01-02 15:04:05")
}
6 changes: 6 additions & 0 deletions dtmutil/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ import (
"github.com/dtm-labs/logger"
_ "github.com/go-sql-driver/mysql" // register mysql driver
_ "github.com/lib/pq" // register postgres driver

// _ "github.com/microsoft/go-mssqldb" // Microsoft's package conflicts with gorm's package: panic: sql: Register called twice for driver mssql
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlserver" // register sqlserver driver,
"gorm.io/gorm"
)

Expand All @@ -27,6 +30,9 @@ func getGormDialetor(driver string, dsn string) gorm.Dialector {
if driver == dtmcli.DBTypePostgres {
return postgres.Open(dsn)
}
if driver == dtmcli.DBTypeSqlServer {
return sqlserver.Open(dsn)
}
dtmimp.PanicIf(driver != dtmcli.DBTypeMysql, fmt.Errorf("unknown driver: %s", driver))
return mysql.Open(dsn)
}
Expand Down
4 changes: 4 additions & 0 deletions test/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ func TestMain(m *testing.M) {
conf.Store.User = ""
conf.Store.Password = ""
conf.Store.Port = 6379
} else if tenv == config.SqlServer {
conf.Store.User = "sa"
conf.Store.Password = "p@ssw0rd"
conf.Store.Port = 1433
}
conf.Store.Db = ""
registry.WaitStoreUp()
Expand Down

0 comments on commit b210ad7

Please sign in to comment.