forked from google/go-cloud
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gcppostgres.go
194 lines (172 loc) · 5.63 KB
/
gcppostgres.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
// Copyright 2018 The Go Cloud Development Kit Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package gcppostgres provides connections to managed PostgreSQL Cloud SQL instances.
// See https://cloud.google.com/sql/docs/postgres/ for more information.
//
// URLs
//
// For postgres.Open, gcppostgres registers for the scheme "gcppostgres".
// The default URL opener will create a connection using the default
// credentials from the environment, as described in
// https://cloud.google.com/docs/authentication/production.
// To customize the URL opener, or for more details on the URL format,
// see URLOpener.
//
// See https://gocloud.dev/concepts/urls/ for background information.
package gcppostgres // import "github.com/kainoaseto/go-cloud/postgres/gcppostgres"
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"net"
"net/url"
"strings"
"sync"
"time"
"contrib.go.opencensus.io/integrations/ocsql"
"github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/proxy"
"github.com/lib/pq"
"github.com/kainoaseto/go-cloud/gcp"
"github.com/kainoaseto/go-cloud/gcp/cloudsql"
"github.com/kainoaseto/go-cloud/postgres"
)
// Scheme is the URL scheme gcppostgres registers its URLOpener under on
// postgres.DefaultMux.
const Scheme = "gcppostgres"
func init() {
postgres.DefaultURLMux().RegisterPostgres(Scheme, new(lazyCredsOpener))
}
// lazyCredsOpener obtains Application Default Credentials on the first call
// to OpenPostgresURL and builds a URLOpener from them.
//
// The outcome of that first attempt, success or failure, is recorded once via
// sync.Once and reused by every later call; a failed attempt is not retried.
type lazyCredsOpener struct {
	init   sync.Once
	opener *URLOpener
	err    error
}

// OpenPostgresURL opens the database identified by u using a URLOpener whose
// CertSource is backed by Application Default Credentials.
//
// Only the ctx passed to the first call is used for the credential lookup.
// On failure, the returned error includes u with any password in its
// userinfo replaced by "xxxxx".
func (o *lazyCredsOpener) OpenPostgresURL(ctx context.Context, u *url.URL) (*sql.DB, error) {
	o.init.Do(func() {
		creds, err := gcp.DefaultCredentials(ctx)
		if err != nil {
			o.err = err
			return
		}
		client, err := gcp.NewHTTPClient(gcp.DefaultTransport(), creds.TokenSource)
		if err != nil {
			o.err = err
			return
		}
		o.opener = &URLOpener{CertSource: cloudsql.NewCertSource(client)}
	})
	if o.err != nil {
		// URLs may embed "user:password@"; never echo the password into an
		// error message that is likely to end up in logs.
		shown := *u
		if shown.User != nil {
			if _, hasPassword := shown.User.Password(); hasPassword {
				shown.User = url.UserPassword(shown.User.Username(), "xxxxx")
			}
		}
		return nil, fmt.Errorf("gcppostgres: open %v: %v", &shown, o.err)
	}
	return o.opener.OpenPostgresURL(ctx, u)
}
// URLOpener opens GCP PostgreSQL URLs
// like "gcppostgres://user:password@myproject/us-central1/instanceId/mydb".
//
// The host and the first two path segments name the Cloud SQL instance
// (project, region, instance); the remainder of the path names the database.
// Other query parameters are forwarded to lib/pq, except the TLS-related ones
// (sslmode, sslcert, sslkey, sslrootcert), which are rejected because sslmode
// is forced to "disable" and connections go through the Cloud SQL proxy client.
type URLOpener struct {
	// CertSource specifies how the opener will obtain authentication information.
	// CertSource must not be nil.
	CertSource proxy.CertSource
	// TraceOpts contains options for OpenCensus.
	TraceOpts []ocsql.TraceOption
}

// OpenPostgresURL opens a new GCP database connection wrapped with OpenCensus instrumentation.
//
// It returns an error if CertSource is nil, if u is not of the form
// project/region/instance/dbname, or if u carries a TLS-related query
// parameter. sql.OpenDB does not dial; connections are established on use.
// Errors never include the password from u.
func (uo *URLOpener) OpenPostgresURL(ctx context.Context, u *url.URL) (*sql.DB, error) {
	if uo.CertSource == nil {
		return nil, errors.New("gcppostgres: URLOpener CertSource is nil")
	}
	instance, dbName, err := instanceFromURL(u)
	if err != nil {
		return nil, fmt.Errorf("gcppostgres: open %v: %v", redactURLPassword(u), err)
	}
	query := u.Query()
	for k := range query {
		// Only permit parameters that do not conflict with other behavior.
		switch k {
		case "sslmode", "sslcert", "sslkey", "sslrootcert":
			return nil, fmt.Errorf("gcppostgres: open: extra parameter %s not allowed", k)
		}
	}
	query.Set("sslmode", "disable")

	// Rewrite into a postgres:// URL for lib/pq. The host is only a
	// placeholder: dialer ignores the address and dials the instance through
	// the proxy client. Userinfo from u is carried over unchanged.
	u2 := new(url.URL)
	*u2 = *u
	u2.Scheme = "postgres"
	u2.Host = "cloudsql"
	u2.Path = "/" + dbName
	u2.RawQuery = query.Encode()

	db := sql.OpenDB(connector{
		client: &proxy.Client{
			Port:  3307,
			Certs: uo.CertSource,
		},
		instance: instance,
		pqConn:   u2.String(),
		// Copy so later changes to uo.TraceOpts do not affect this DB.
		traceOpts: append([]ocsql.TraceOption(nil), uo.TraceOpts...),
	})
	return db, nil
}

// redactURLPassword returns u formatted as a string with any password in its
// userinfo replaced by "xxxxx", so the result is safe to put in error messages.
func redactURLPassword(u *url.URL) string {
	if u.User == nil {
		return u.String()
	}
	if _, hasPassword := u.User.Password(); !hasPassword {
		return u.String()
	}
	redacted := *u
	redacted.User = url.UserPassword(u.User.Username(), "xxxxx")
	return redacted.String()
}
// instanceFromURL splits u's host and path into a Cloud SQL instance
// connection name of the form "project:region:instance" and a database name.
// The database name is everything after the third slash and may itself
// contain slashes. An error is returned unless all four components are
// present and non-empty.
func instanceFromURL(u *url.URL) (instance, db string, _ error) {
	full := u.Host + u.Path // everything after scheme but before query or fragment
	fields := strings.SplitN(full, "/", 4)
	malformed := len(fields) != 4
	for i := 0; i < len(fields) && !malformed; i++ {
		malformed = fields[i] == ""
	}
	if malformed {
		return "", "", fmt.Errorf("%s is not in the form project/region/instance/dbname", full)
	}
	return strings.Join(fields[:3], ":"), fields[3], nil
}
// pqDriver is the driver.Driver returned by connector.Driver. It opens
// connections to a single Cloud SQL instance through the proxy client,
// treating the name passed to Open/OpenConnector as the lib/pq connection
// string.
type pqDriver struct {
	client    *proxy.Client
	instance  string
	traceOpts []ocsql.TraceOption
}

// Open implements driver.Driver by building a connector for name and
// connecting with a background context, since driver.Driver.Open takes no
// context.
func (d pqDriver) Open(name string) (driver.Conn, error) {
	c, err := d.OpenConnector(name)
	if err != nil {
		return nil, err
	}
	return c.Connect(context.Background())
}

// OpenConnector returns a connector that dials d.instance and uses name as
// the lib/pq connection string. It currently always returns a nil error.
func (d pqDriver) OpenConnector(name string) (driver.Connector, error) {
	return connector{
		client:    d.client,
		instance:  d.instance,
		pqConn:    name,
		traceOpts: d.traceOpts,
	}, nil
}
// connector is the driver.Connector handed to sql.OpenDB. Each Connect dials
// the Cloud SQL instance through the proxy client, opens a lib/pq connection
// over that socket, and wraps it with OpenCensus tracing.
type connector struct {
	// client dials the Cloud SQL instance.
	client *proxy.Client
	// instance is the connection name produced by instanceFromURL,
	// "project:region:instance".
	instance string
	// pqConn is the lib/pq connection string (a postgres:// URL).
	pqConn string
	// traceOpts are passed to ocsql.WrapConn for every new connection.
	traceOpts []ocsql.TraceOption
}

// Connect implements driver.Connector. The context argument is not used:
// the dial goes through pq.DialOpen, which takes no context.
func (c connector) Connect(context.Context) (driver.Conn, error) {
	d := dialer{client: c.client, instance: c.instance}
	raw, err := pq.DialOpen(d, c.pqConn)
	if err != nil {
		return nil, err
	}
	return ocsql.WrapConn(raw, c.traceOpts...), nil
}

// Driver implements driver.Connector, returning a pqDriver bound to the same
// proxy client, instance, and trace options.
func (c connector) Driver() driver.Driver {
	return pqDriver{
		client:    c.client,
		instance:  c.instance,
		traceOpts: c.traceOpts,
	}
}
// dialer is the lib/pq dialer used by connector.Connect. It routes every
// connection to one Cloud SQL instance via the proxy client.
type dialer struct {
	client   *proxy.Client
	instance string
}

// Dial ignores network and address: the host in the lib/pq connection string
// is a placeholder, and the connection always goes to d.instance through
// d.client.
func (d dialer) Dial(network, address string) (net.Conn, error) {
	conn, err := d.client.Dial(d.instance)
	return conn, err
}

// DialTimeout always fails; this dialer does not support timeouts.
//
// NOTE(review): lib/pq presumably calls DialTimeout when the connection
// string sets connect_timeout, so such URLs would fail to connect — confirm.
func (d dialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
	notSupported := errors.New("gcppostgres: DialTimeout not supported")
	return nil, notSupported
}