/
orm.go
123 lines (102 loc) · 2.59 KB
/
orm.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
package database
import (
"context"
"database/sql"
"fmt"
"github.com/gookit/color"
"github.com/pkg/errors"
"github.com/goravel/framework/contracts/config"
ormcontract "github.com/goravel/framework/contracts/database/orm"
databasegorm "github.com/goravel/framework/database/gorm"
"github.com/goravel/framework/database/orm"
)
// OrmImpl is the Orm facade bound to a single named connection. Instances
// derived from it via Connection and WithContext share the same queries map.
type OrmImpl struct {
	// ctx is passed to databasegorm.InitializeQuery when a new connection is
	// initialized and is carried over to instances derived via Connection.
	ctx context.Context
	// config supplies "database.default" and is passed to InitializeQuery.
	config config.Config
	// connection is the name of the connection this instance is bound to.
	connection string
	// query is the query for connection; it is what Query() returns.
	query ormcontract.Query
	// queries caches initialized queries by connection name. The same map is
	// handed to every derived instance, so writes are visible to all of them.
	// NOTE(review): the map is not guarded by a mutex; concurrent Connection
	// calls that initialize new connections race on it.
	queries map[string]ormcontract.Query
}
// NewOrmImpl creates an OrmImpl bound to connection, using query as that
// connection's query and seeding the per-connection cache with it.
// The returned error is currently always nil.
func NewOrmImpl(ctx context.Context, config config.Config, connection string, query ormcontract.Query) (*OrmImpl, error) {
	cache := make(map[string]ormcontract.Query, 1)
	cache[connection] = query

	impl := &OrmImpl{
		ctx:        ctx,
		config:     config,
		connection: connection,
		query:      query,
		queries:    cache,
	}

	return impl, nil
}
// Connection returns an Orm bound to the named connection. An empty name
// selects the connection configured under "database.default".
//
// A connection's query is initialized on first use with
// databasegorm.InitializeQuery and stored in the queries map shared with r,
// so later calls for the same name reuse it. If initialization fails (or
// yields a nil query) the problem is printed and nil is returned.
//
// NOTE(review): r.queries is written without synchronization; concurrent
// calls that initialize previously unseen connections race on the map.
func (r *OrmImpl) Connection(name string) ormcontract.Orm {
	if name == "" {
		name = r.config.GetString("database.default")
	}

	query, exist := r.queries[name]
	if !exist {
		initialized, err := databasegorm.InitializeQuery(r.ctx, r.config, name)
		if err != nil {
			color.Redln(fmt.Sprintf("[Orm] Init %s connection error: %v", name, err))
			return nil
		}
		if initialized == nil {
			// Report a nil query explicitly instead of printing a misleading "<nil>" error.
			color.Redln(fmt.Sprintf("[Orm] Init %s connection error: query is nil", name))
			return nil
		}
		r.queries[name] = initialized
		query = initialized
	}

	return &OrmImpl{
		ctx:        r.ctx,
		config:     r.config,
		connection: name,
		query:      query,
		queries:    r.queries,
	}
}
// DB returns the underlying *sql.DB of this instance's connection.
//
// It requires the query to be a *databasegorm.QueryImpl. For any other Query
// implementation it returns an error rather than panicking on a failed type
// assertion.
func (r *OrmImpl) DB() (*sql.DB, error) {
	query, ok := r.Query().(*databasegorm.QueryImpl)
	if !ok {
		return nil, fmt.Errorf("orm connection %q: unsupported query type %T", r.connection, r.Query())
	}

	return query.Instance().DB()
}
// Query returns the query bound to this instance's connection.
func (r *OrmImpl) Query() ormcontract.Query {
	bound := r.query
	return bound
}
// Factory returns a new factory built on this instance's query.
func (r *OrmImpl) Factory() ormcontract.Factory {
	query := r.Query()
	return NewFactoryImpl(query)
}
// Observe registers observer for model by appending to the package-level
// orm.Observers list. The registration is global, not scoped to r's connection.
//
// NOTE(review): the append to orm.Observers is unsynchronized; register
// observers during startup, not concurrently.
func (r *OrmImpl) Observe(model any, observer ormcontract.Observer) {
	entry := orm.Observer{Model: model, Observer: observer}
	orm.Observers = append(orm.Observers, entry)
}
// Transaction begins a transaction on this instance's query and passes it to
// txFunc.
//
//   - If txFunc returns nil, the transaction is committed and Commit's error is
//     returned.
//   - If txFunc returns an error, the transaction is rolled back and txFunc's
//     error is returned. If the rollback also fails, the returned error wraps
//     txFunc's error and carries the rollback failure in its message.
//   - If txFunc panics, the transaction is rolled back and the panic is
//     re-raised.
func (r *OrmImpl) Transaction(txFunc func(tx ormcontract.Transaction) error) error {
	tx, err := r.Query().Begin()
	if err != nil {
		return err
	}

	// Without this, a panicking txFunc would leave the transaction open.
	defer func() {
		if p := recover(); p != nil {
			_ = tx.Rollback() // best effort; the panic is the primary failure
			panic(p)
		}
	}()

	if err := txFunc(tx); err != nil {
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			// Keep txFunc's error as the cause so callers can still inspect it.
			return errors.Wrapf(err, "rollback error: %v", rollbackErr)
		}
		return err
	}

	return tx.Commit()
}
// WithContext returns an Orm for the same connection carrying ctx.
//
// It calls SetContext(ctx) on every cached query and on this instance's query
// that is a *databasegorm.QueryImpl. Other Query implementations are left
// unchanged instead of triggering a panic from a failed type assertion.
//
// NOTE(review): the queries map is shared with r and every Orm derived from it,
// and presumably so are the *QueryImpl values. This call therefore also changes
// the context those instances see, and concurrent calls race. Confirm whether
// QueryImpl offers a non-mutating, copy-returning alternative.
func (r *OrmImpl) WithContext(ctx context.Context) ormcontract.Orm {
	for _, cached := range r.queries {
		if impl, ok := cached.(*databasegorm.QueryImpl); ok {
			impl.SetContext(ctx)
		}
	}
	if impl, ok := r.query.(*databasegorm.QueryImpl); ok {
		impl.SetContext(ctx)
	}

	return &OrmImpl{
		ctx:        ctx,
		config:     r.config,
		connection: r.connection,
		query:      r.query,
		queries:    r.queries,
	}
}