/
driver.go
121 lines (107 loc) · 3.58 KB
/
driver.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
package pgtenant
import (
"context"
"database/sql"
"database/sql/driver"
"github.com/lib/pq"
"github.com/pkg/errors"
)
// Driver is the pgtenant database/sql driver.
// It satisfies both driver.Driver and driver.DriverContext.
type Driver struct {
	// TenantIDCol names the column, present in every table of the db schema,
	// that holds the tenant ID.
	TenantIDCol string

	// Whitelist maps each SQL query string to the output expected from transforming it.
	// It plays two roles:
	//
	//  1. Allow-list of permitted queries.
	//     A database connection created from this driver refuses to execute a query
	//     unless the query appears here or has been "escaped"
	//     by attaching it to a context object using WithQuery.
	//
	//  2. Cache of precomputed transforms.
	//
	// Lookup is by exact string match
	// (after some minimal whitespace trimming)
	// on the query string given to QueryContext or ExecContext.
	//
	// Pass the same value to a unit test that calls TransformTester
	// to verify that the pre- and post-transform queries are correct.
	Whitelist map[string]Transformed

	// dynamicCache is an unexported query cache.
	// NOTE(review): presumably holds transforms for queries that are not in Whitelist;
	// confirm against the queryCache definition.
	dynamicCache queryCache
}
// Transformed holds what the transformer produces for a single query:
// the rewritten query text and the number of the positional parameter
// that was added to carry the tenant-ID value.
type Transformed struct {
	Query string // the transformed query
	Num   int    // number of the positional parameter added for the tenant ID
}
// Compile-time checks that *Driver satisfies
// the driver.Driver and driver.DriverContext interfaces.
var (
	_ driver.Driver        = (*Driver)(nil)
	_ driver.DriverContext = (*Driver)(nil)
)
// Open implements driver.Driver.Open.
// It obtains a connector for name via OpenConnector and connects right away,
// using a background context because this entry point receives none.
func (d *Driver) Open(name string) (driver.Conn, error) {
	cn, err := d.OpenConnector(name)
	if err != nil {
		return nil, err
	}
	ctx := context.Background()
	return cn.Connect(ctx)
}
// OpenConnector implements driver.DriverContext.OpenConnector.
//
// It builds a pq connector from the data source name and wraps it in a
// Connector bound to d. If pq.NewConnector fails, its error is returned
// unchanged together with a nil connector.
func (d *Driver) OpenConnector(name string) (driver.Connector, error) {
	c, err := pq.NewConnector(name)
	if err != nil {
		// Return a literal nil rather than a Connector wrapping a nil
		// *pq.Connector: a caller that mistakenly used such a value would
		// only fail later, inside Connect, with a nil-pointer dereference.
		return nil, err
	}
	return &Connector{nested: c, driver: d}, nil
}
// Connector implements driver.Connector.
// It pairs a pq connector with the Driver that created it.
type Connector struct {
	// nested is the underlying pq connector used to open connections.
	nested *pq.Connector

	// driver is the Driver this connector belongs to;
	// it is reported by the Driver method and handed to every Conn created by Connect.
	driver *Driver
}
// Compile-time check that *Connector satisfies the driver.Connector interface.
var _ driver.Connector = (*Connector)(nil)
// Connect implements driver.Connector.Connect.
//
// It opens a connection through the nested pq connector and wraps it in a
// Conn that carries this connector's Driver. An error from the nested
// connector is wrapped with the message "connecting to database". If the
// nested connection does not satisfy ctxConn, it is closed and an error is
// returned.
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
	nestedConn, err := c.nested.Connect(ctx)
	if err != nil {
		return nil, errors.Wrap(err, "connecting to database")
	}
	nestedCtx, ok := nestedConn.(ctxConn)
	if !ok {
		// The connection is already open at this point; close it so it is
		// not leaked. The Close error is deliberately ignored: the caller
		// needs to see the unsupported-connection error below.
		_ = nestedConn.Close()
		return nil, errors.New("connection does not support context")
	}
	conn := &Conn{
		ctxConn: nestedCtx,
		driver:  c.driver,
	}
	return conn, nil
}
// Driver implements driver.Connector.Driver.
// It reports the Driver this connector was created with.
func (c *Connector) Driver() driver.Driver {
	return c.driver
}
// Open builds a Driver from tenantIDCol and whitelist,
// opens a connector for dsn with it,
// and returns a *sql.DB backed by that connector.
//
// It is shorthand for either of the following sequences.
// The first:
//
//	driver := &pgtenant.Driver{TenantIDCol: tenantIDCol, Whitelist: whitelist}
//	sql.Register(driverName, driver)
//	db, err := sql.Open(driverName, dsn)
//
// The second:
//
//	driver := &pgtenant.Driver{TenantIDCol: tenantIDCol, Whitelist: whitelist}
//	connector, err := driver.OpenConnector(dsn)
//	if err != nil { ... }
//	db := sql.OpenDB(connector)
//
// The first sequence yields a reusable driver object that can open multiple different databases.
// The second additionally yields a reusable connector object that can open the same database multiple times.
func Open(dsn, tenantIDCol string, whitelist map[string]Transformed) (*sql.DB, error) {
	drv := &Driver{TenantIDCol: tenantIDCol, Whitelist: whitelist}
	connector, err := drv.OpenConnector(dsn)
	if err != nil {
		return nil, err
	}
	db := sql.OpenDB(connector)
	return db, nil
}