Skip to content

Commit

Permalink
Merge branch 'timeout' of https://github.com/cyberdelia/pgx into cybe…
Browse files Browse the repository at this point in the history
…rdelia-timeout
  • Loading branch information
jackc committed Jan 13, 2018
2 parents 3707b79 + 1bec450 commit 9281f05
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 8 deletions.
29 changes: 27 additions & 2 deletions conn.go
Expand Up @@ -72,6 +72,7 @@ type ConnConfig struct {
Logger Logger
LogLevel int
Dial DialFunc
Timeout time.Duration
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
OnNotice NoticeHandler // Callback function called when a notice response is received.
CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc.
Expand Down Expand Up @@ -259,7 +260,7 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)

network, address := c.config.networkAddress()
if c.config.Dial == nil {
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
c.config.Dial = (&net.Dialer{Timeout: c.config.Timeout, KeepAlive: 5 * time.Minute}).Dial
}

if c.shouldLog(LogLevelInfo) {
Expand Down Expand Up @@ -686,13 +687,22 @@ func ParseURI(uri string) (ConnConfig, error) {
}
cp.Database = strings.TrimLeft(url.Path, "/")

if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" {
timeout, err := strconv.ParseInt(pgtimeout, 10, 64)
if err != nil {
return cp, err
}
cp.Timeout = time.Duration(timeout) * time.Second
}

err = configSSL(url.Query().Get("sslmode"), &cp)
if err != nil {
return cp, err
}

ignoreKeys := map[string]struct{}{
"sslmode": {},
"sslmode": {},
"connect_timeout": {},
}

cp.RuntimeParams = make(map[string]string)
Expand Down Expand Up @@ -750,6 +760,12 @@ func ParseDSN(s string) (ConnConfig, error) {
cp.Database = b[2]
case "sslmode":
sslmode = b[2]
case "connect_timeout":
t, err := strconv.ParseInt(b[2], 10, 64)
if err != nil {
return cp, err
}
cp.Timeout = time.Duration(t) * time.Second
default:
cp.RuntimeParams[b[1]] = b[2]
}
Expand Down Expand Up @@ -787,6 +803,7 @@ func ParseConnectionString(s string) (ConnConfig, error) {
// PGPASSWORD
// PGSSLMODE
// PGAPPNAME
// PGCONNECT_TIMEOUT
//
// Important TLS Security Notes:
// ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This
Expand Down Expand Up @@ -822,6 +839,14 @@ func ParseEnvLibpq() (ConnConfig, error) {
cc.User = os.Getenv("PGUSER")
cc.Password = os.Getenv("PGPASSWORD")

if pgtimeout := os.Getenv("PGCONNECT_TIMEOUT"); pgtimeout != "" {
if timeout, err := strconv.ParseInt(pgtimeout, 10, 64); err == nil {
cc.Timeout = time.Duration(timeout) * time.Second
} else {
return cc, err
}
}

sslmode := os.Getenv("PGSSLMODE")

err := configSSL(sslmode, &cc)
Expand Down
44 changes: 38 additions & 6 deletions conn_test.go
Expand Up @@ -567,6 +567,21 @@ func TestParseDSN(t *testing.T) {
},
},
},
{
url: "user=jack host=localhost dbname=mydb connect_timeout=10",
connParams: pgx.ConnConfig{
User: "jack",
Host: "localhost",
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
Timeout: 10 * time.Second,
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
}

for i, tt := range tests {
Expand Down Expand Up @@ -697,6 +712,21 @@ func TestParseConnectionString(t *testing.T) {
},
},
},
{
url: "postgres://jack@localhost/mydb?connect_timeout=10",
connParams: pgx.ConnConfig{
User: "jack",
Host: "localhost",
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
Timeout: 10 * time.Second,
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable",
connParams: pgx.ConnConfig{
Expand Down Expand Up @@ -802,7 +832,7 @@ func TestParseConnectionString(t *testing.T) {
}

func TestParseEnvLibpq(t *testing.T) {
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE"}
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"}

savedEnv := make(map[string]string)
for _, n := range pgEnvvars {
Expand Down Expand Up @@ -835,11 +865,12 @@ func TestParseEnvLibpq(t *testing.T) {
{
name: "Normal PG vars",
envvars: map[string]string{
"PGHOST": "123.123.123.123",
"PGPORT": "7777",
"PGDATABASE": "foo",
"PGUSER": "bar",
"PGPASSWORD": "baz",
"PGHOST": "123.123.123.123",
"PGPORT": "7777",
"PGDATABASE": "foo",
"PGUSER": "bar",
"PGPASSWORD": "baz",
"PGCONNECT_TIMEOUT": "10",
},
config: pgx.ConnConfig{
Host: "123.123.123.123",
Expand All @@ -850,6 +881,7 @@ func TestParseEnvLibpq(t *testing.T) {
TLSConfig: &tls.Config{InsecureSkipVerify: true},
UseFallbackTLS: true,
FallbackTLSConfig: nil,
Timeout: 10 * time.Second,
RuntimeParams: map[string]string{},
},
},
Expand Down

0 comments on commit 9281f05

Please sign in to comment.