/
mongo.go
123 lines (100 loc) · 2.9 KB
/
mongo.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
package mongo
import (
"context"
"fmt"
"time"
"github.com/infraboard/mcube/v2/ioc"
"github.com/infraboard/mcube/v2/ioc/config/application"
"github.com/infraboard/mcube/v2/ioc/config/log"
"github.com/infraboard/mcube/v2/ioc/config/trace"
"github.com/rs/zerolog"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.opentelemetry.io/contrib/instrumentation/go.mongodb.org/mongo-driver/mongo/otelmongo"
)
// init registers defaultConfig with the container returned by ioc.Config()
// when this package is imported.
func init() {
	registry := ioc.Config()
	registry.Registry(defaultConfig)
}
// defaultConfig is the MongoDB configuration registered in init. Out of the
// box it targets a single local server, authenticates against "admin", uses
// the application name as the default database and allows tracing.
//
// NOTE(review): application.Get().Name() is evaluated during package
// initialisation; confirm the application name is already final by then.
var defaultConfig = &mongoDB{
	Endpoints:   []string{"127.0.0.1:27017"},
	AuthDB:      "admin",
	Database:    application.Get().Name(),
	EnableTrace: true,
}
// mongoDB is the MongoDB configuration/component object registered with the
// ioc config container. Exported fields are presumably filled from
// toml/json/yaml config or environment variables according to their struct
// tags (TODO confirm against the ioc loader); the unexported fields hold
// runtime state created in Init.
type mongoDB struct {
// Endpoints lists the server addresses handed to the driver via SetHosts.
// The envSeparator tag suggests the env value is split on "," (verify).
Endpoints []string `toml:"endpoints" json:"endpoints" yaml:"endpoints" env:"ENDPOINTS" envSeparator:","`
// UserName and Password are attached as credentials only when both are non-empty.
UserName string `toml:"username" json:"username" yaml:"username" env:"USERNAME"`
// NOTE(review): Password has json/yaml tags, so any code that serialises
// this config will emit it in clear text — confirm it is never logged/dumped.
Password string `toml:"password" json:"password" yaml:"password" env:"PASSWORD"`
// Database is the database returned by GetDB and the fallback
// authentication source when AuthDB is empty.
Database string `toml:"database" json:"database" yaml:"database" env:"DATABASE"`
// AuthDB is the authentication source database ("admin" in defaultConfig).
AuthDB string `toml:"auth_db" json:"auth_db" yaml:"auth_db" env:"AUTH_DB"`
// EnableTrace installs the OpenTelemetry command monitor, but only when
// tracing is also enabled globally (trace.Get().Enable).
EnableTrace bool `toml:"enable_trace" json:"enable_trace" yaml:"enable_trace" env:"ENABLE_TRACE"`
// client is the connected driver client; set by Init.
client *mongo.Client
// ObjectImpl is embedded, presumably to supply default implementations of
// ioc object methods not defined on this type — TODO confirm.
ioc.ObjectImpl
// log is this component's sub-logger; set by Init.
log *zerolog.Logger
}
// Name returns the identifier of this component, the package-level AppName.
// It is also used as the name of the sub-logger created in Init.
func (m *mongoDB) Name() string {
	name := AppName
	return name
}
// Priority returns the fixed priority value (698) this component reports to
// the ioc container. NOTE(review): how the container orders components by
// this value is defined in ioc, not here.
//
// The receiver is named m, matching every other method on mongoDB.
func (m *mongoDB) Priority() int {
	return 698
}
// Init prepares the component. It creates a sub-logger named after the
// component, then builds and verifies a MongoDB client via getClient. On
// success the client is stored for Client and GetDB; on failure the error
// from getClient is returned and the stored client is left unchanged.
func (m *mongoDB) Init() error {
	m.log = log.Sub(m.Name())

	client, err := m.getClient()
	if err == nil {
		m.client = client
	}
	return err
}
// Close disconnects the database client using ctx. It is a no-op that
// returns nil when no client has been created (e.g. Init never succeeded).
func (m *mongoDB) Close(ctx context.Context) error {
	if c := m.client; c != nil {
		return c.Disconnect(ctx)
	}
	return nil
}
// GetAuthDB returns the database used as the authentication source:
// AuthDB when it is set, otherwise Database.
func (m *mongoDB) GetAuthDB() string {
	authDB := m.AuthDB
	if authDB == "" {
		authDB = m.Database
	}
	return authDB
}
// GetDB returns a handle to the configured Database on the shared client.
// m.client is not nil-checked here, so Init must have succeeded first.
func (m *mongoDB) GetDB() *mongo.Database {
	dbName := m.Database
	return m.client.Database(dbName)
}
// Client returns the shared (global) MongoDB client created by Init.
// It is nil until Init has completed successfully.
func (m *mongoDB) Client() *mongo.Client {
	c := m.client
	return c
}
// getClient builds driver options from the configuration, connects to the
// configured endpoints and pings the server to verify connectivity.
//
// Credentials are attached only when both UserName and Password are set,
// using GetAuthDB as the authentication source. When tracing is enabled both
// globally (trace.Get().Enable) and for this component (EnableTrace), an
// OpenTelemetry command monitor is installed with command attributes
// disabled. m.log must already be set (Init does this before calling here).
//
// Connect and Ping share one 5-second deadline. If the ping fails, the newly
// created client is disconnected before returning so its connection pool and
// background monitoring are not leaked. Driver errors are wrapped with %w so
// callers can inspect them with errors.Is / errors.As.
func (m *mongoDB) getClient() (*mongo.Client, error) {
	const timeout = 5 * time.Second

	opts := options.Client()
	if m.UserName != "" && m.Password != "" {
		opts.SetAuth(options.Credential{
			AuthSource:  m.GetAuthDB(),
			Username:    m.UserName,
			Password:    m.Password,
			PasswordSet: true,
		})
	}
	opts.SetHosts(m.Endpoints)
	opts.SetConnectTimeout(timeout)

	if trace.Get().Enable && m.EnableTrace {
		m.log.Info().Msg("enable mongodb trace")
		// Presumably keeps full command documents (which may contain query
		// values) out of span attributes — see otelmongo option docs.
		opts.Monitor = otelmongo.NewMonitor(
			otelmongo.WithCommandAttributeDisabled(true),
		)
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// Connect to MongoDB
	client, err := mongo.Connect(ctx, opts)
	if err != nil {
		return nil, fmt.Errorf("new mongodb client error, %w", err)
	}

	if err := client.Ping(ctx, nil); err != nil {
		// Release the client's resources. A fresh context is used because
		// ctx may already be past its deadline when Ping timed out.
		dctx, dcancel := context.WithTimeout(context.Background(), timeout)
		defer dcancel()
		_ = client.Disconnect(dctx) // best effort; the ping error is what we report
		return nil, fmt.Errorf("ping mongodb server(%s) error, %w", m.Endpoints, err)
	}
	return client, nil
}