-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.go
289 lines (254 loc) · 7.13 KB
/
data.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
package dac
import (
"fmt"
"gorm.io/gorm"
"reflect"
"runtime"
"strings"
)
// DataAccess is the database-specific query-building hook registered per
// DBType (see RegisterDatabase). Every method takes the current *gorm.DB
// chain and returns the chain to continue with.
type DataAccess interface {
	// Source and projection.
	Table(db *gorm.DB, name string) *gorm.DB
	Select(db *gorm.DB, fields []Field) *gorm.DB

	// Result shaping.
	Group(db *gorm.DB, group string) *gorm.DB
	Order(db *gorm.DB, order string) *gorm.DB
	Limit(db *gorm.DB, page, pageSize int64) *gorm.DB

	// Operations that read results into out / count.
	Find(db *gorm.DB, out interface{}) *gorm.DB
	First(db *gorm.DB, out interface{}) *gorm.DB
	Last(db *gorm.DB, out interface{}) *gorm.DB
	Count(db *gorm.DB, count *int64) *gorm.DB
}
// registeredDataAccess is the global registry mapping each DBType to the
// DataAccess implementation registered for it via RegisterDatabase. It starts
// out nil and is allocated lazily on the first registration.
// NOTE(review): no locking is visible around this map, so registration
// presumably has to complete (e.g. during init) before any concurrent
// GetDataAccess calls — confirm.
var registeredDataAccess map[DBType]DataAccess
// RegisterDatabase associates dataAccess with dbType so that later lookups
// (GetDataAccess, NewDatabase) for that type return it. Registering the same
// dbType again replaces the previous implementation.
func RegisterDatabase(dbType DBType, dataAccess DataAccess) {
	if registeredDataAccess != nil {
		registeredDataAccess[dbType] = dataAccess
		return
	}
	// First registration: allocate the registry with this entry in it.
	registeredDataAccess = map[DBType]DataAccess{dbType: dataAccess}
}
// GetDataAccess returns the DataAccess registered for dbType, or nil when no
// implementation has been registered for it.
func GetDataAccess(dbType DBType) DataAccess {
	// Indexing a missing key (or the still-nil registry) yields the zero
	// value, which for an interface type is nil.
	return registeredDataAccess[dbType]
}
// Database wraps a gorm query chain together with the database type it
// targets, so that type-specific rendering (fields, where/having conditions,
// pagination) can be applied while the query is built.
type Database struct {
	// db is the current gorm chain; chaining methods replace it with the
	// *gorm.DB gorm returns (see useSourceDB). It is nil until Use is called.
db *gorm.DB
	// dbType selects the database-specific behaviour for this wrapper.
dbType DBType
	// da is the DataAccess looked up by NewDatabase; nil when dbType has not
	// been registered.
da DataAccess
	// err holds a wrapper-level error (set by Having); Error() reports it in
	// preference to db.Error.
err error
}
// NewDatabase builds a Database targeting dbType and wires in the DataAccess
// registered for that type (nil if none). Attach a *gorm.DB with Use before
// issuing queries; NewDatabase does not set one.
func NewDatabase(dbType DBType) *Database {
	d := &Database{dbType: dbType}
	d.da = GetDataAccess(dbType)
	return d
}
// Use attaches db as the gorm chain this Database builds on and returns d so
// calls can be chained.
func (d *Database) Use(db *gorm.DB) *Database {
	return d.useSourceDB(db)
}
// useSourceDB stores next as the current gorm chain and returns d, letting
// every wrapper method forward gorm's returned *gorm.DB in one expression.
func (d *Database) useSourceDB(next *gorm.DB) *Database {
	d.db = next
	return d
}
// Table points the query at the named table via gorm's Table.
func (d *Database) Table(name string) *Database {
	scoped := d.db.Table(name)
	return d.useSourceDB(scoped)
}
// AutoMigrate validates every model in dst against the column types supported
// for d's database type and, only if all of them pass, runs gorm's AutoMigrate
// on the whole set. Validation stops at the first offending model and its
// error is returned without touching the schema.
//
// Models may be passed as struct values or as (possibly nested) pointers to
// structs. Non-struct models are not validated here and are left for gorm to
// accept or reject.
func (d *Database) AutoMigrate(dst ...interface{}) error {
	for _, model := range dst {
		// reflect.TypeOf(nil) is nil; calling Kind on it would panic.
		if model == nil {
			return fmt.Errorf("dac: AutoMigrate: nil model")
		}
		t := reflect.TypeOf(model)
		// Accept User{}, &User{} and deeper pointers alike.
		for t.Kind() == reflect.Ptr {
			t = t.Elem()
		}
		// Field-by-field validation only makes sense for structs.
		if t.Kind() != reflect.Struct {
			continue
		}
		if err := autoMigrateStruct(d.dbType, t); err != nil {
			return err
		}
	}
	return d.db.AutoMigrate(dst...)
}
// autoMigrateStruct checks that every column of struct type t has a type
// supported for dbType, recursing into embedded structs, and returns an error
// naming the first offending field (nil when all fields pass).
//
// Per field:
//   - unexported fields and fields tagged gorm:"-" are skipped;
//   - embedded structs (by value or pointer) are validated recursively;
//   - if extractTypeFromGormTag finds a type in the gorm tag, that type is
//     validated with IsDatabaseTypeSupported;
//   - otherwise slice fields and fields tagged dac:"-" are skipped, and the Go
//     type (kind name, or qualified name such as "time.Time" for structs,
//     lower-cased) is validated with IsConstantTypeSupported.
//
// A non-struct t has no fields to check and yields nil.
func autoMigrateStruct(dbType DBType, t reflect.Type) error {
	// NumField panics for non-struct kinds.
	if t.Kind() != reflect.Struct {
		return nil
	}
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)

		// gorm only maps exported fields to columns.
		if !field.IsExported() {
			continue
		}

		tag := field.Tag.Get("gorm")
		// gorm:"-" removes the field from the schema, embedded or not, so it
		// must not be validated either.
		if tag == "-" {
			continue
		}

		// Embedded structs contribute their own fields to this table.
		if field.Anonymous {
			embedded := field.Type
			if embedded.Kind() == reflect.Ptr {
				embedded = embedded.Elem()
			}
			if embedded.Kind() == reflect.Struct {
				if err := autoMigrateStruct(dbType, embedded); err != nil {
					return err
				}
				continue
			}
		}

		// An explicit column type in the gorm tag takes precedence over the Go type.
		if typeValue := extractTypeFromGormTag(tag); typeValue != "" {
			if !IsDatabaseTypeSupported(typeValue) {
				return fmt.Errorf("field %s of struct %s has unsupported tag type %s", field.Name, t.Name(), typeValue)
			}
			// NOTE(review): the tag cannot be rewritten to the database-specific
			// type here — reflect.StructField is a copy and struct tags are
			// immutable at runtime. Any per-database mapping (ReplaceFieldType)
			// has to happen where gorm builds the column definition.
			continue
		}

		// Slices (typically associations) and fields opted out via dac:"-" have
		// no column type to check.
		if field.Type.Kind() == reflect.Slice || field.Tag.Get("dac") == "-" {
			continue
		}
		fieldType := field.Type.Kind().String()
		if field.Type.Kind() == reflect.Struct {
			// Use the qualified type name (e.g. "time.Time") for struct columns.
			fieldType = field.Type.String()
		}
		if !IsConstantTypeSupported(strings.ToLower(fieldType)) {
			return fmt.Errorf("field %s of struct %s has unsupported type %s", field.Name, t.Name(), field.Type)
		}
	}
	return nil
}
// unaliasType strips every level of pointer indirection from t and returns the
// pointed-to type (e.g. **T -> T). All other kinds, including interfaces, are
// returned unchanged, and a nil t yields nil.
//
// Interfaces are deliberately not unwrapped: reflect.Type.Elem panics for
// Interface kinds, so they have no element type to descend into.
func unaliasType(t reflect.Type) reflect.Type {
	for t != nil && t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	return t
}
// Where adds the conditions described by buildOption to the query; they are
// rendered for d's database type by buildWhereConditions.
func (d *Database) Where(buildOption *BuilderOption) *Database {
	return d.useSourceDB(buildWhereConditions(d.db, d.dbType, buildOption))
}
// Find runs the query and stores all matching rows into out via gorm's Find.
func (d *Database) Find(out any) *Database {
	result := d.db.Find(out)
	return d.useSourceDB(result)
}
// Create inserts out via gorm's Create.
func (d *Database) Create(out any) *Database {
	result := d.db.Create(out)
	return d.useSourceDB(result)
}
// Save persists out via gorm's Save.
func (d *Database) Save(out any) *Database {
	result := d.db.Save(out)
	return d.useSourceDB(result)
}
// Updates updates columns from out (typically a struct) via gorm's Updates;
// when out is a struct, only its non-zero fields are written.
func (d *Database) Updates(out any) *Database {
	result := d.db.Updates(out)
	return d.useSourceDB(result)
}
// Update sets a single column to value via gorm's Update.
func (d *Database) Update(column string, value any) *Database {
	result := d.db.Update(column, value)
	return d.useSourceDB(result)
}
// Delete removes the records identified by out via gorm's Delete.
func (d *Database) Delete(out any) *Database {
	result := d.db.Delete(out)
	return d.useSourceDB(result)
}
// Having adds HAVING conditions built from builder for d's database type.
// A failure is recorded in d.err (reported later by Error) instead of aborting
// the chain.
// NOTE(review): the *gorm.DB is not replaced here, so this presumably relies on
// addHavingConditions mutating d.db's statement in place — confirm.
func (d *Database) Having(builder *ConditionBuilder) *Database {
	if err := addHavingConditions(d.db, d.dbType, builder); err != nil {
		d.err = err
	}
	return d
}
// Scan runs the query and copies the results into out via gorm's Scan.
func (d *Database) Scan(out any) *Database {
	result := d.db.Scan(out)
	return d.useSourceDB(result)
}
// First fetches the first matching record into out via gorm's First; any
// error (including "record not found") is available through Error().
//
// Previously this method returned d without running a query, so out was never
// populated.
func (d *Database) First(out any) *Database {
	return d.useSourceDB(d.db.First(out))
}
// Last fetches the last matching record into out via gorm's Last.
func (d *Database) Last(out any) *Database {
	result := d.db.Last(out)
	return d.useSourceDB(result)
}
// Count stores the number of matching rows into *count via gorm's Count.
func (d *Database) Count(count *int64) *Database {
	result := d.db.Count(count)
	return d.useSourceDB(result)
}
// Joins adds a join to the query; query and args are forwarded unchanged to
// gorm's Joins.
func (d *Database) Joins(query string, args ...any) *Database {
	return d.useSourceDB(d.db.Joins(query, args...))
}
// Join adds "JOIN <tableWithAlias> on <condition>" to the query.
// NOTE(review): both arguments are spliced into raw SQL and must not come from
// untrusted input.
func (d *Database) Join(tableWithAlias, condition string) *Database {
	clause := "JOIN " + tableWithAlias + " on " + condition
	return d.useSourceDB(d.db.Joins(clause))
}
// LeftJoin adds "LEFT JOIN <tableWithAlias> on <condition>" to the query.
// NOTE(review): both arguments are spliced into raw SQL and must not come from
// untrusted input.
func (d *Database) LeftJoin(tableWithAlias, condition string) *Database {
	clause := "LEFT JOIN " + tableWithAlias + " on " + condition
	return d.useSourceDB(d.db.Joins(clause))
}
// Preload forwards query and args unchanged to gorm's Preload so the named
// association is loaded along with the main query.
func (d *Database) Preload(query string, args ...any) *Database {
	return d.useSourceDB(d.db.Preload(query, args...))
}
// Select restricts the query to the given fields. Each element is rendered to
// SQL by parseField for d's database type, and the non-empty results are
// joined with "," into a single select clause passed to gorm's Select.
//
// Fields that render to an empty string are skipped so they cannot produce
// stray commas (e.g. "a,,b") in the generated SQL.
func (d *Database) Select(fields ...any) *Database {
	columns := make([]string, 0, len(fields))
	for _, f := range fields {
		if col := parseField(f, d.dbType); col != "" {
			columns = append(columns, col)
		}
	}
	return d.useSourceDB(d.db.Select(strings.Join(columns, ",")))
}
// Pluck queries the single column described by column (rendered by parseField
// for d's database type) and hands desc to gorm's Pluck as the destination for
// the column's values.
//
// The previous version accumulated into an always-empty local string, so its
// comma-joining branch was unreachable; the rendered column is now passed
// directly.
func (d *Database) Pluck(column any, desc any) *Database {
	return d.useSourceDB(d.db.Pluck(parseField(column, d.dbType), desc))
}
// Model sets the model gorm uses for subsequent operations via gorm's Model.
func (d *Database) Model(model any) *Database {
	scoped := d.db.Model(model)
	return d.useSourceDB(scoped)
}
// DB exposes the underlying gorm chain so callers can use gorm features this
// wrapper does not cover.
func (d *Database) DB() *gorm.DB {
	raw := d.db
	return raw
}
// Limit applies pagination through the DataAccess registered for d's database
// type; how page and pageSize are translated into SQL is up to that
// implementation.
//
// If no DataAccess is registered (NewDatabase looked up nil), calling through
// the nil interface would panic. Instead the query is left unpaginated and an
// error is recorded, unless an earlier one is already pending; it is reported
// by Error().
func (d *Database) Limit(page, pageSize int) *Database {
	if d.da == nil {
		if d.err == nil {
			d.err = fmt.Errorf("dac: no DataAccess registered for database type %v", d.dbType)
		}
		return d
	}
	return d.useSourceDB(d.da.Limit(d.db, int64(page), int64(pageSize)))
}
// Group adds a GROUP BY clause; group is passed to gorm's Group verbatim.
func (d *Database) Group(group string) *Database {
	grouped := d.db.Group(group)
	return d.useSourceDB(grouped)
}
// Order adds an ORDER BY clause; order is passed to gorm's Order verbatim.
func (d *Database) Order(order string) *Database {
	ordered := d.db.Order(order)
	return d.useSourceDB(ordered)
}
// Error returns the pending error. A wrapper-level error recorded in d.err
// (e.g. by Having) takes precedence over the error on the gorm chain. A non-nil
// result is also printed via PrintCallerInfo, which reports the location that
// called Error.
// NOTE(review): if neither Use nor a chaining call has set d.db and d.err is
// nil, reading d.db.Error panics.
func (d *Database) Error() error {
	err := d.err
	if err == nil {
		err = d.db.Error
	}
	if err != nil {
		// Must stay a direct call: PrintCallerInfo uses a fixed stack depth.
		PrintCallerInfo(err)
	}
	return err
}
// PrintCallerInfo prints err together with the file and line of the caller of
// PrintCallerInfo's caller (stack depth 2). When invoked from
// (*Database).Error, that is the code that called Error(). Output goes to
// standard output and is newline-terminated.
func PrintCallerInfo(err error) {
	// Skip PrintCallerInfo itself (0) and its direct caller (1).
	_, file, line, ok := runtime.Caller(2)
	if !ok {
		fmt.Printf("Failed to retrieve caller information, error: %v\n", err)
		return
	}
	fmt.Printf("Caller file: %s, line: %d, error: %v\n", file, line, err)
}