Skip to content

Commit

Permalink
feature 增强的 ShardingAlgorithm 设计与实现 (#162)
Browse files Browse the repository at this point in the history
  • Loading branch information
Stone-afk committed Mar 13, 2023
1 parent 277a881 commit 29ec110
Show file tree
Hide file tree
Showing 26 changed files with 3,654 additions and 851 deletions.
2 changes: 1 addition & 1 deletion .CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
- [eorm: 新增 ShardingSelector 实现](https://github.com/gotomicro/eorm/pull/145)
- [eorm: 基于dns的slave发现](https://github.com/ecodeclub/eorm/pull/152)
- [eorm: 分库分表: Merger抽象与批量查询实现](https://github.com/ecodeclub/eorm/pull/160)

- [eorm: 增强的 ShardingAlgorithm 设计与实现](https://github.com/ecodeclub/eorm/pull/161)

## v0.0.1:
- [Init Project](https://github.com/ecodeclub/eorm/pull/1)
Expand Down
16 changes: 8 additions & 8 deletions builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,17 @@ type Query struct {
// Querier 查询器,代表最基本的查询
type Querier[T any] struct {
core
session
Session
qc *QueryContext
}

// RawQuery 创建一个 Querier 实例
// 泛型参数 T 是目标类型。
// 例如,如果查询 User 的数据, 那么 T 就是 User
func RawQuery[T any](sess session, sql string, args ...any) Querier[T] {
func RawQuery[T any](sess Session, sql string, args ...any) Querier[T] {
return Querier[T]{
core: sess.getCore(),
session: sess,
Session: sess,
qc: &QueryContext{
q: &Query{
SQL: sql,
Expand All @@ -57,10 +57,10 @@ func RawQuery[T any](sess session, sql string, args ...any) Querier[T] {
}
}

func newQuerier[T any](sess session, q *Query, meta *model.TableMeta, typ string) Querier[T] {
func newQuerier[T any](sess Session, q *Query, meta *model.TableMeta, typ string) Querier[T] {
return Querier[T]{
core: sess.getCore(),
session: sess,
Session: sess,
qc: &QueryContext{
q: q,
meta: meta,
Expand All @@ -72,7 +72,7 @@ func newQuerier[T any](sess session, q *Query, meta *model.TableMeta, typ string
// Exec 执行 SQL
func (q Querier[T]) Exec(ctx context.Context) Result {
var handler HandleFunc = func(ctx context.Context, qc *QueryContext) *QueryResult {
res, err := q.session.execContext(ctx, qc.q.SQL, qc.q.Args...)
res, err := q.Session.execContext(ctx, qc.q.SQL, qc.q.Args...)
return &QueryResult{Result: res, Err: err}
}

Expand All @@ -92,7 +92,7 @@ func (q Querier[T]) Exec(ctx context.Context) Result {
// 注意在不同的数据库里面,排序可能会不同
// 在没有查找到数据的情况下,会返回 ErrNoRows
func (q Querier[T]) Get(ctx context.Context) (*T, error) {
res := get[T](ctx, q.session, q.core, q.qc)
res := get[T](ctx, q.Session, q.core, q.qc)
if res.Err != nil {
return nil, res.Err
}
Expand Down Expand Up @@ -278,7 +278,7 @@ func (b *builder) buildIns(is values) error {
}

func (q Querier[T]) GetMulti(ctx context.Context) ([]*T, error) {
res := getMulti[T](ctx, q.session, q.core, q.qc)
res := getMulti[T](ctx, q.Session, q.core, q.qc)
if res.Err != nil {
return nil, res.Err
}
Expand Down
10 changes: 5 additions & 5 deletions core.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ import (
)

type core struct {
ms []Middleware
metaRegistry model.MetaRegistry
dialect dialect.Dialect
valCreator valuer.BasicTypeCreator
ms []Middleware
}

func getHandler[T any](ctx context.Context, sess session, c core, qc *QueryContext) *QueryResult {
func getHandler[T any](ctx context.Context, sess Session, c core, qc *QueryContext) *QueryResult {
rows, err := sess.queryContext(ctx, qc.q.SQL, qc.q.Args...)
if err != nil {
return &QueryResult{Err: err}
Expand All @@ -56,7 +56,7 @@ func getHandler[T any](ctx context.Context, sess session, c core, qc *QueryConte
return &QueryResult{Result: tp}
}

func get[T any](ctx context.Context, sess session, core core, qc *QueryContext) *QueryResult {
func get[T any](ctx context.Context, sess Session, core core, qc *QueryContext) *QueryResult {
var handler HandleFunc = func(ctx context.Context, queryContext *QueryContext) *QueryResult {
return getHandler[T](ctx, sess, core, queryContext)
}
Expand All @@ -67,7 +67,7 @@ func get[T any](ctx context.Context, sess session, core core, qc *QueryContext)
return handler(ctx, qc)
}

func getMultiHandler[T any](ctx context.Context, sess session, c core, qc *QueryContext) *QueryResult {
func getMultiHandler[T any](ctx context.Context, sess Session, c core, qc *QueryContext) *QueryResult {
rows, err := sess.queryContext(ctx, qc.q.SQL, qc.q.Args...)
if err != nil {
return &QueryResult{Err: err}
Expand All @@ -94,7 +94,7 @@ func getMultiHandler[T any](ctx context.Context, sess session, c core, qc *Query
return &QueryResult{Result: res}
}

func getMulti[T any](ctx context.Context, sess session, core core, qc *QueryContext) *QueryResult {
func getMulti[T any](ctx context.Context, sess Session, core core, qc *QueryContext) *QueryResult {
var handler HandleFunc = func(ctx context.Context, queryContext *QueryContext) *QueryResult {
return getMultiHandler[T](ctx, sess, core, queryContext)
}
Expand Down
8 changes: 4 additions & 4 deletions delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,19 @@ import (
// Deleter builds DELETE query
type Deleter[T any] struct {
builder
session
Session
table interface{}
where []Predicate
}

// NewDeleter 开始构建一个 DELETE 查询
func NewDeleter[T any](sess session) *Deleter[T] {
func NewDeleter[T any](sess Session) *Deleter[T] {
return &Deleter[T]{
builder: builder{
core: sess.getCore(),
buffer: bytebufferpool.Get(),
},
session: sess,
Session: sess,
}
}

Expand Down Expand Up @@ -82,5 +82,5 @@ func (d *Deleter[T]) Exec(ctx context.Context) Result {
if err != nil {
return Result{err: err}
}
return newQuerier[T](d.session, query, d.meta, DELETE).Exec(ctx)
return newQuerier[T](d.Session, query, d.meta, DELETE).Exec(ctx)
}
8 changes: 4 additions & 4 deletions insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,20 @@ import (
// More details check Build function
type Inserter[T any] struct {
builder
session
Session
columns []string
values []*T
ignorePK bool
}

// NewInserter 开始构建一个 INSERT 查询
func NewInserter[T any](sess session) *Inserter[T] {
func NewInserter[T any](sess Session) *Inserter[T] {
return &Inserter[T]{
builder: builder{
core: sess.getCore(),
buffer: bytebufferpool.Get(),
},
session: sess,
Session: sess,
}
}

Expand Down Expand Up @@ -116,7 +116,7 @@ func (i *Inserter[T]) Exec(ctx context.Context) Result {
if err != nil {
return Result{err: err}
}
return newQuerier[T](i.session, query, i.meta, INSERT).Exec(ctx)
return newQuerier[T](i.Session, query, i.meta, INSERT).Exec(ctx)
}

func (i *Inserter[T]) buildColumns() ([]*model.ColumnMeta, error) {
Expand Down
15 changes: 8 additions & 7 deletions internal/errs/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,16 @@ var (
ErrCombinationIsNotStruct = errors.New("eorm: 不支持的组合类型,eorm 只支持结构体组合")
ErrMissingShardingKey = errors.New("eorm: sharding key 未设置")
ErrOnlyResultOneQuery = errors.New("eorm: 只能生成一个 SQL")
ErrNotGenShardingQuery = errors.New("eorm: 未生成 sharding query")
ErrUnsupportedTooComplexQuery = errors.New("eorm: 暂未支持太复杂的查询")
ErrExcShardingAlgorithm = errors.New("eorm: 执行 sharding algorithm 出错")
ErrCtxGetDBName = errors.New("eorm: ctx 获取目标 dbName 出错")
ErrNotFoundTargetDB = errors.New("eorm: 未发现目标 DB")
ErrNotFoundTargetTable = errors.New("eorm: 未发现目标 Table")
ErrSlaveNotFound = errors.New("eorm: slave不存在")
ErrMergerEmptyRows = errors.New("eorm: sql.Rows列表为空")
ErrMergerRowsIsNull = errors.New("eorm: sql.Rows列表中有元素为nil")
ErrMergerEmptyRows = errors.New("eorm: sql.Rows 列表为空")
ErrMergerRowsIsNull = errors.New("eorm: sql.Rows 列表中有元素为 nil")
ErrNotFoundTargetDataSource = errors.New("eorm: 未发现目标 data source")
ErrNotFoundTargetDB = errors.New("eorm: 未发现目标 DB")
ErrRepeatedSetDB = errors.New("eorm: 重复设置 DB")
ErrNotGenShardingQuery = errors.New("eorm: 未生成 sharding query")

// ErrExcShardingAlgorithm = errors.New("eorm: 执行 sharding algorithm 出错")
)

func NewFieldConflictError(field string) error {
Expand Down
96 changes: 65 additions & 31 deletions internal/integration/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"log"
"time"

Expand All @@ -28,7 +29,8 @@ import (
"github.com/ecodeclub/eorm/internal/slaves/dns"

"github.com/ecodeclub/eorm"
"github.com/ecodeclub/eorm/internal/model"
"github.com/ecodeclub/eorm/internal/sharding"
"github.com/ecodeclub/eorm/internal/sharding/datasource"
"github.com/ecodeclub/eorm/internal/slaves"
"github.com/ecodeclub/eorm/internal/slaves/roundrobin"
"github.com/stretchr/testify/suite"
Expand All @@ -53,20 +55,29 @@ func (s *Suite) SetupSuite() {
s.orm = orm
}

type clusterDrivers struct {
clDrivers []*clusterDriver
}

type clusterDriver struct {
msDrivers []*masterSalvesDriver
}

type masterSalvesDriver struct {
masterdsn string
slavedsns []string
}

type ShardingSuite struct {
suite.Suite
slaves slaves.Slaves
driver string
tbSet map[string]bool
driverMap map[string]*masterSalvesDriver
shardingDB *eorm.ShardingDB
dbSf model.ShardingAlgorithm
tableSf model.ShardingAlgorithm
slaves slaves.Slaves
clusters *clusterDrivers
shardingDB *eorm.ShardingDB
algorithm sharding.Algorithm
dataSources map[string]sharding.DataSource
driver string
DBPattern string
DsPattern string
}

func (s *ShardingSuite) openDB(dvr, dsn string) (*sql.DB, error) {
Expand All @@ -81,31 +92,40 @@ func (s *ShardingSuite) openDB(dvr, dsn string) (*sql.DB, error) {
}

func (s *ShardingSuite) initDB() (*eorm.ShardingDB, error) {
masterSlaveDBs := make(map[string]*eorm.MasterSlavesDB, 8)
for k, v := range s.driverMap {
master, err := s.openDB(s.driver, v.masterdsn)
if err != nil {
return nil, err
}
ss := make([]*sql.DB, 0, len(v.slavedsns))
for _, slavedsn := range v.slavedsns {
slave, err := s.openDB(s.driver, slavedsn)
clDrivers := s.clusters.clDrivers
sourceMap := make(map[string]sharding.DataSource, len(clDrivers))
for i, cluster := range clDrivers {
msMap := make(map[string]*eorm.MasterSlavesDB, 8)
for j, d := range cluster.msDrivers {
master, err := s.openDB(s.driver, d.masterdsn)
if err != nil {
return nil, err
}
ss = append(ss, slave)
}
sl, err := roundrobin.NewSlaves(ss...)
require.NoError(s.T(), err)
s.slaves = newTestSlaves(sl)
masterSlaveDB, err := eorm.OpenMasterSlaveDB(
s.driver, master, eorm.MasterSlaveWithSlaves(s.slaves))
if err != nil {
return nil, err
ss := make([]*sql.DB, 0, len(d.slavedsns))
for _, slavedsn := range d.slavedsns {
slave, err := s.openDB(s.driver, slavedsn)
if err != nil {
return nil, err
}
ss = append(ss, slave)
}
sl, err := roundrobin.NewSlaves(ss...)
require.NoError(s.T(), err)
s.slaves = &testBaseSlaves{Slaves: sl}
masterSlaveDB, err := eorm.OpenMasterSlaveDB(
s.driver, master, eorm.MasterSlaveWithSlaves(s.slaves))
if err != nil {
return nil, err
}
dbName := fmt.Sprintf(s.DBPattern, j)
msMap[dbName] = masterSlaveDB
}
masterSlaveDBs[k] = masterSlaveDB
sourceName := fmt.Sprintf(s.DsPattern, i)
sourceMap[sourceName] = eorm.OpenClusterDB(msMap)
}
return eorm.OpenShardingDB(s.driver, masterSlaveDBs, eorm.ShardingDBOptionWithTables(s.tbSet))
s.dataSources = sourceMap
dataSource := datasource.NewShardingDataSource(sourceMap)
return eorm.OpenShardingDB(s.driver, dataSource)
}

func (s *ShardingSuite) SetupSuite() {
Expand Down Expand Up @@ -162,15 +182,29 @@ func (s *MasterSlaveSuite) initDb() (*eorm.MasterSlavesDB, error) {
// return slave, err
//}

type testSlaves struct {
type testBaseSlaves struct {
slaves.Slaves
}

func (s *testBaseSlaves) Next(ctx context.Context) (slaves.Slave, error) {
slave, err := s.Slaves.Next(ctx)
if err != nil {
return slave, err
}
return slave, err
}

type testSlaves struct {
*testBaseSlaves
ch chan string
}

func newTestSlaves(s slaves.Slaves) *testSlaves {
return &testSlaves{
Slaves: s,
ch: make(chan string, 1),
testBaseSlaves: &testBaseSlaves{
Slaves: s,
},
ch: make(chan string, 1),
}
}

Expand Down
30 changes: 10 additions & 20 deletions internal/integration/select_masterslave_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,15 @@ func (s *MasterSlaveSelectTestSuite) TestMasterSlave() {
},
},
// TODO 从库测试目前有查不到数据的bug
//{
// name: "query use slave",
// i: eorm.NewSelector[test.SimpleStruct](s.orm).Where(eorm.C("Id").LT(4)),
// wantSlave: "0",
// wantRes: s.data,
// ctx: func() context.Context {
// return context.Background()
// },
//},
{
name: "query use slave",
i: eorm.NewSelector[test.SimpleStruct](s.orm).Where(eorm.C("Id").LT(4)),
wantSlave: "0",
wantRes: s.data,
ctx: func() context.Context {
return context.Background()
},
},
}
for _, tc := range testcases {
s.T().Run(tc.name, func(t *testing.T) {
Expand Down Expand Up @@ -150,6 +150,7 @@ func (s *MasterSlaveDNSTestSuite) TestDNSMasterSlave() {
wantSlave string
ctx func() context.Context
}{
// TODO 从库测试目前有查不到数据的bug
{
name: "get slave with dns",
i: eorm.NewSelector[test.SimpleStruct](s.orm).Where(eorm.C("Id").LT(4)),
Expand Down Expand Up @@ -179,14 +180,3 @@ func (s *MasterSlaveDNSTestSuite) TestDNSMasterSlave() {
})
}
}

type Hash struct {
// 有占位符就是要分集群,没有就不分
DatasoucePattern string
// 有占位符就分库,没有就不分
DatabasePattern string
// 有占位符就分表,没有就不分
TablePattern string

Base int
}

0 comments on commit 29ec110

Please sign in to comment.