/
postgres.go
141 lines (120 loc) · 4.1 KB
/
postgres.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
package postgres
import (
	"context"
	"database/sql"
	"fmt"
	"strings"
	"unicode"

	"github.com/XSAM/otelsql"
	"github.com/georgysavva/scany/sqlscan"
	_ "github.com/lib/pq" //nolint
	"github.com/pkg/errors"
	"go.opentelemetry.io/otel"
	semconv "go.opentelemetry.io/otel/semconv/v1.10.0"
)
// Postgres is a DBConn backed by a PostgreSQL database.
//
// Build one with New, call Connect before using any query method, and call
// Disconnect to close the underlying pool.
type Postgres struct {
	connStr      string  // key/value connection string prepared by New
	dbConnection *sql.DB // pool opened by Connect; nil until Connect succeeds
	config       *Config // settings passed to New; Connect reads pool limits from it
}

// Compile-time check that *Postgres implements DBConn.
var _ DBConn = (*Postgres)(nil)

const (
	// postgresDriver is the database/sql driver name that Connect wraps with otelsql.
	postgresDriver = "postgres"
	// instrumentationName is the OTel tracer name used for the storage.* spans.
	instrumentationName = "storage"
)
// New builds a Postgres storage for the database described by config.
//
// Only the connection string is prepared here; no connection is attempted
// until Connect is called. config must be non-nil (its fields are read
// immediately).
func New(config *Config) *Postgres {
	connStr := createConnectionString(
		config.DBHost,
		config.DBPort,
		config.DBName,
		config.DBUser,
		config.DBPassword,
		config.DBSchema,
	)
	return &Postgres{
		connStr: connStr,
		config:  config,
	}
}
// Connect opens the connection pool for p and verifies the database is
// reachable.
//
// The "postgres" driver is wrapped with otelsql (tagged with the PostgreSQL
// db.system attribute) and the pool is opened with p's connection string. Pool
// limits from p.config are applied before the verification ping.
//
// On success the pool is stored on p for Exec, BeginTx, SelectAll and Select.
// On any failure p is left unchanged, and a pool opened during this call is
// closed rather than leaked.
//
// NOTE(review): calling Connect again replaces p.dbConnection without closing
// the previous pool, and registers another otelsql driver each time — confirm
// Connect is only called once per Postgres.
func (p *Postgres) Connect(ctx context.Context) error {
	// Register an OTel-instrumented wrapper around the postgres driver.
	driverName, err := otelsql.Register(postgresDriver, otelsql.WithAttributes(semconv.DBSystemPostgreSQL))
	if err != nil {
		return errors.Wrap(err, "failed to hook the tracer to the database driver")
	}

	// sql.Open may only validate its arguments; the ping below is what
	// actually establishes a connection.
	db, err := sql.Open(driverName, p.connStr)
	if err != nil {
		return errors.Wrap(err, "failed to open connection")
	}

	// Configure the pool before first use so the ping's connection is already
	// governed by these limits.
	db.SetMaxOpenConns(p.config.MaxOpenConnections)
	db.SetMaxIdleConns(p.config.MaxIdleConnections)
	db.SetConnMaxLifetime(p.config.ConnectionMaxLifetime)

	if err := db.PingContext(ctx); err != nil {
		// The pool is unusable and will not be stored; close it so it is not
		// leaked. The ping error is the one worth reporting, so a Close error
		// here is deliberately dropped.
		_ = db.Close()
		return errors.Wrap(err, "failed to ping database connection")
	}

	p.dbConnection = db
	return nil
}
// createConnectionString builds the lib/pq key/value connection string
//
//	host=H port=P user=U dbname=N sslmode=disable [password=W] [search_path=S]
//
// from the supplied connection details. TLS is always disabled.
//
// Every string value goes through quoteConnValue. Values that are empty or
// contain whitespace, a single quote or a backslash (common in passwords) are
// therefore quoted and escaped. Otherwise they would be split into bogus extra
// parameters, or, when empty, swallow the following key as their value.
func createConnectionString(host string, port int, name, user string, password string, schema string) string {
	info := fmt.Sprintf("host=%s port=%d user=%s dbname=%s sslmode=disable",
		quoteConnValue(host), port, quoteConnValue(user), quoteConnValue(name))
	// The Postgres driver gets confused in cases where the user has no password
	// set but a password is passed, so only set password if it's non-empty.
	if password != "" {
		info += " password=" + quoteConnValue(password)
	}
	if schema != "" {
		info += " search_path=" + quoteConnValue(schema)
	}
	return info
}

// quoteConnValue renders v as a value in a libpq-style key/value connection
// string.
//
// Plain values are returned unchanged. Values that are empty or contain
// whitespace, a single quote or a backslash are wrapped in single quotes, with
// each quote and backslash escaped by a preceding backslash.
func quoteConnValue(v string) string {
	plain := v != "" && strings.IndexFunc(v, func(r rune) bool {
		return unicode.IsSpace(r) || r == '\'' || r == '\\'
	}) < 0
	if plain {
		return v
	}
	var b strings.Builder
	b.Grow(len(v) + 2)
	b.WriteByte('\'')
	// Byte-wise copy is safe: the escaped characters are ASCII, so multi-byte
	// UTF-8 sequences pass through untouched.
	for i := 0; i < len(v); i++ {
		if c := v[i]; c == '\'' || c == '\\' {
			b.WriteByte('\\')
		}
		b.WriteByte(v[i])
	}
	b.WriteByte('\'')
	return b.String()
}
// Exec runs a statement that returns no rows and reports its sql.Result.
//
// The call is wrapped in a "storage.exec" span, and the span's context is
// what gets passed to the driver. Connect must have succeeded first;
// otherwise the pool is nil.
func (p *Postgres) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	ctx, span := otel.Tracer(instrumentationName).Start(ctx, "storage.exec")
	defer span.End()

	return p.dbConnection.ExecContext(ctx, query, args...)
}
// BeginTx starts a new database transaction with the given options.
//
// The "storage.beginTx" span ends when BeginTx returns, so it covers only
// starting the transaction, not the statements run on the returned *sql.Tx.
// The context handed to database/sql is derived from ctx. Per database/sql
// semantics, cancelling ctx therefore rolls the transaction back.
func (p *Postgres) BeginTx(ctx context.Context, txOptions *sql.TxOptions) (*sql.Tx, error) {
	tracer := otel.Tracer(instrumentationName)
	txCtx, span := tracer.Start(ctx, "storage.beginTx")
	defer span.End()

	return p.dbConnection.BeginTx(txCtx, txOptions)
}
// SelectAll runs query and scans all resulting rows into dst via
// sqlscan.Select (presumably dst must be a pointer to a slice, per scany).
// The call is wrapped in a "storage.selectAll" span.
//
// sql.ErrNoRows is treated as success and yields nil; any other error is
// returned as-is.
func (p *Postgres) SelectAll(ctx context.Context, dst interface{}, query string, args ...interface{}) error {
	ctx, span := otel.Tracer(instrumentationName).Start(ctx, "storage.selectAll")
	defer span.End()

	err := sqlscan.Select(ctx, p.dbConnection, dst, query, args...)
	if errors.Is(err, sql.ErrNoRows) {
		return nil
	}
	return err
}
// Select runs query and scans a single row into dst via sqlscan.Get. The call
// is wrapped in a "storage.select" span.
//
// When the query matches no row, the sql.ErrNoRows error is swallowed and nil
// is returned. NOTE(review): "not found" is thus indistinguishable from
// success here; callers must detect absence from dst itself.
func (p *Postgres) Select(ctx context.Context, dst interface{}, query string, args ...interface{}) error {
	ctx, span := otel.Tracer(instrumentationName).Start(ctx, "storage.select")
	defer span.End()

	switch err := sqlscan.Get(ctx, p.dbConnection, dst, query, args...); {
	case err == nil, errors.Is(err, sql.ErrNoRows):
		return nil
	default:
		return err
	}
}
// Disconnect closes the connection pool opened by Connect.
//
// It returns nil without doing anything when no pool has been opened. ctx is
// unused (presumably accepted to satisfy the DBConn interface).
func (p *Postgres) Disconnect(ctx context.Context) error {
	if db := p.dbConnection; db != nil {
		return db.Close()
	}
	return nil
}