From 2c07b0308759517c9e367ec774e9dc18d51b71c7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Jan 2018 18:02:13 -0600 Subject: [PATCH] Parse connect_timeout into Dial func Instead of adding Timeout field which could conflict with custom Dial func. --- conn.go | 22 +++++--- conn_test.go | 144 ++++++++++++++++++++++++++------------------------- 2 files changed, 90 insertions(+), 76 deletions(-) diff --git a/conn.go b/conn.go index 6fe4ba728..125d90329 100644 --- a/conn.go +++ b/conn.go @@ -72,7 +72,6 @@ type ConnConfig struct { Logger Logger LogLevel int Dial DialFunc - Timeout time.Duration RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) OnNotice NoticeHandler // Callback function called when a notice response is received. CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc. @@ -224,6 +223,10 @@ func Connect(config ConnConfig) (c *Conn, err error) { return connect(config, minimalConnInfo) } +func defaultDialer() *net.Dialer { + return &net.Dialer{KeepAlive: 5 * time.Minute} +} + func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) { c = new(Conn) @@ -260,7 +263,8 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) network, address := c.config.networkAddress() if c.config.Dial == nil { - c.config.Dial = (&net.Dialer{Timeout: c.config.Timeout, KeepAlive: 5 * time.Minute}).Dial + d := defaultDialer() + c.config.Dial = d.Dial } if c.shouldLog(LogLevelInfo) { @@ -692,7 +696,9 @@ func ParseURI(uri string) (ConnConfig, error) { if err != nil { return cp, err } - cp.Timeout = time.Duration(timeout) * time.Second + d := defaultDialer() + d.Timeout = time.Duration(timeout) * time.Second + cp.Dial = d.Dial } err = configSSL(url.Query().Get("sslmode"), &cp) @@ -761,11 +767,13 @@ func ParseDSN(s string) (ConnConfig, error) { case "sslmode": sslmode = b[2] case "connect_timeout": - t, err := strconv.ParseInt(b[2], 10, 64) + timeout, err := strconv.ParseInt(b[2], 10, 64) if err != nil { return cp, err } - cp.Timeout = time.Duration(t) * time.Second + d := defaultDialer() + d.Timeout = time.Duration(timeout) * time.Second + cp.Dial = d.Dial default: cp.RuntimeParams[b[1]] = b[2] } @@ -841,7 +849,9 @@ func ParseEnvLibpq() (ConnConfig, error) { if pgtimeout := os.Getenv("PGCONNECT_TIMEOUT"); pgtimeout != "" { if timeout, err := strconv.ParseInt(pgtimeout, 10, 64); err == nil { - cc.Timeout = time.Duration(timeout) * time.Second + d := defaultDialer() + d.Timeout = time.Duration(timeout) * time.Second + cc.Dial = d.Dial } else { return cc, err } diff --git a/conn_test.go b/conn_test.go index 696a4003f..d82f4bdec 100644 --- a/conn_test.go +++ b/conn_test.go @@ -576,7 +576,7 @@ func TestParseDSN(t *testing.T) { TLSConfig: &tls.Config{ InsecureSkipVerify: true, }, - Timeout: 10 * time.Second, + Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial, UseFallbackTLS: true, FallbackTLSConfig: nil, RuntimeParams: map[string]string{}, @@ -585,15 +585,13 @@ func TestParseDSN(t *testing.T) { } for i, tt := range tests { - connParams, err := pgx.ParseDSN(tt.url) + actual, err := pgx.ParseDSN(tt.url) if err != nil { t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err) continue } - if !reflect.DeepEqual(connParams, tt.connParams) { - t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams) - } + testConnConfigEquals(t, tt.connParams, actual, strconv.Itoa(i)) } } @@ -721,7 +719,7 @@ func TestParseConnectionString(t *testing.T) { TLSConfig: &tls.Config{ InsecureSkipVerify: true, }, - Timeout: 10 * time.Second, + Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial, UseFallbackTLS: true, FallbackTLSConfig: nil, RuntimeParams: map[string]string{}, @@ -819,16 +817,80 @@ func TestParseConnectionString(t *testing.T) { } for i, tt := range tests { - connParams, err := pgx.ParseConnectionString(tt.url) + actual, err := pgx.ParseConnectionString(tt.url) if err != nil { t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err) continue } - if !reflect.DeepEqual(connParams, tt.connParams) { - t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams) + testConnConfigEquals(t, tt.connParams, actual, strconv.Itoa(i)) + } +} + +func testConnConfigEquals(t *testing.T, expected pgx.ConnConfig, actual pgx.ConnConfig, testName string) { + if actual.Host != expected.Host { + t.Errorf("%s: expected Host to be %v got %v", testName, expected.Host, actual.Host) + } + if actual.Port != expected.Port { + t.Errorf("%s: expected Port to be %v got %v", testName, expected.Port, actual.Port) + } + if actual.Port != expected.Port { + t.Errorf("%s: expected Port to be %v got %v", testName, expected.Port, actual.Port) + } + if actual.User != expected.User { + t.Errorf("%s: expected User to be %v got %v", testName, expected.User, actual.User) + } + if actual.Password != expected.Password { + t.Errorf("%s: expected Password to be %v got %v", testName, expected.Password, actual.Password) + } + // Cannot test value of underlying Dialer stuct but can at least test if Dial func is set. + if (actual.Dial != nil) != (expected.Dial != nil) { + t.Errorf("%s: expected Dial mismatch", testName) + } + + if !reflect.DeepEqual(actual.RuntimeParams, expected.RuntimeParams) { + t.Errorf("%s: expected RuntimeParams to be %#v got %#v", testName, expected.RuntimeParams, actual.RuntimeParams) + } + + tlsTests := []struct { + name string + expected *tls.Config + actual *tls.Config + }{ + { + name: "TLSConfig", + expected: expected.TLSConfig, + actual: actual.TLSConfig, + }, + { + name: "FallbackTLSConfig", + expected: expected.FallbackTLSConfig, + actual: actual.FallbackTLSConfig, + }, + } + for _, tlsTest := range tlsTests { + name := tlsTest.name + expected := tlsTest.expected + actual := tlsTest.actual + + if expected == nil && actual != nil { + t.Errorf("%s / %s: expected nil, but it was set", testName, name) + } else if expected != nil && actual == nil { + t.Errorf("%s / %s: expected to be set, but got nil", testName, name) + } else if expected != nil && actual != nil { + if actual.InsecureSkipVerify != expected.InsecureSkipVerify { + t.Errorf("%s / %s: expected InsecureSkipVerify to be %v got %v", testName, name, expected.InsecureSkipVerify, actual.InsecureSkipVerify) + } + + if actual.ServerName != expected.ServerName { + t.Errorf("%s / %s: expected ServerName to be %v got %v", testName, name, expected.ServerName, actual.ServerName) + } } } + + if actual.UseFallbackTLS != expected.UseFallbackTLS { + t.Errorf("%s: expected UseFallbackTLS to be %v got %v", testName, expected.UseFallbackTLS, actual.UseFallbackTLS) + } } func TestParseEnvLibpq(t *testing.T) { @@ -881,7 +943,7 @@ func TestParseEnvLibpq(t *testing.T) { TLSConfig: &tls.Config{InsecureSkipVerify: true}, UseFallbackTLS: true, FallbackTLSConfig: nil, - Timeout: 10 * time.Second, + Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial, RuntimeParams: map[string]string{}, }, }, @@ -997,71 +1059,13 @@ func TestParseEnvLibpq(t *testing.T) { } } - config, err := pgx.ParseEnvLibpq() + actual, err := pgx.ParseEnvLibpq() if err != nil { t.Errorf("%s: Unexpected error from pgx.ParseLibpq() => %v", tt.name, err) continue } - if config.Host != tt.config.Host { - t.Errorf("%s: expected Host to be %v got %v", tt.name, tt.config.Host, config.Host) - } - if config.Port != tt.config.Port { - t.Errorf("%s: expected Port to be %v got %v", tt.name, tt.config.Port, config.Port) - } - if config.Port != tt.config.Port { - t.Errorf("%s: expected Port to be %v got %v", tt.name, tt.config.Port, config.Port) - } - if config.User != tt.config.User { - t.Errorf("%s: expected User to be %v got %v", tt.name, tt.config.User, config.User) - } - if config.Password != tt.config.Password { - t.Errorf("%s: expected Password to be %v got %v", tt.name, tt.config.Password, config.Password) - } - - if !reflect.DeepEqual(config.RuntimeParams, tt.config.RuntimeParams) { - t.Errorf("%s: expected RuntimeParams to be %#v got %#v", tt.name, tt.config.RuntimeParams, config.RuntimeParams) - } - - tlsTests := []struct { - name string - expected *tls.Config - actual *tls.Config - }{ - { - name: "TLSConfig", - expected: tt.config.TLSConfig, - actual: config.TLSConfig, - }, - { - name: "FallbackTLSConfig", - expected: tt.config.FallbackTLSConfig, - actual: config.FallbackTLSConfig, - }, - } - for _, tlsTest := range tlsTests { - name := tlsTest.name - expected := tlsTest.expected - actual := tlsTest.actual - - if expected == nil && actual != nil { - t.Errorf("%s / %s: expected nil, but it was set", tt.name, name) - } else if expected != nil && actual == nil { - t.Errorf("%s / %s: expected to be set, but got nil", tt.name, name) - } else if expected != nil && actual != nil { - if actual.InsecureSkipVerify != expected.InsecureSkipVerify { - t.Errorf("%s / %s: expected InsecureSkipVerify to be %v got %v", tt.name, name, expected.InsecureSkipVerify, actual.InsecureSkipVerify) - } - - if actual.ServerName != expected.ServerName { - t.Errorf("%s / %s: expected ServerName to be %v got %v", tt.name, name, expected.ServerName, actual.ServerName) - } - } - } - - if config.UseFallbackTLS != tt.config.UseFallbackTLS { - t.Errorf("%s: expected UseFallbackTLS to be %v got %v", tt.name, tt.config.UseFallbackTLS, config.UseFallbackTLS) - } + testConnConfigEquals(t, tt.config, actual, tt.name) } }