Skip to content

Commit

Permalink
Parse connect_timeout into Dial func
Browse files Browse the repository at this point in the history
Instead of adding Timeout field which could conflict with custom Dial
func.
  • Loading branch information
jackc committed Jan 14, 2018
1 parent 9281f05 commit 2c07b03
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 76 deletions.
22 changes: 16 additions & 6 deletions conn.go
Expand Up @@ -72,7 +72,6 @@ type ConnConfig struct {
Logger Logger
LogLevel int
Dial DialFunc
Timeout time.Duration
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
OnNotice NoticeHandler // Callback function called when a notice response is received.
CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc.
Expand Down Expand Up @@ -224,6 +223,10 @@ func Connect(config ConnConfig) (c *Conn, err error) {
return connect(config, minimalConnInfo)
}

func defaultDialer() *net.Dialer {
return &net.Dialer{KeepAlive: 5 * time.Minute}
}

func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) {
c = new(Conn)

Expand Down Expand Up @@ -260,7 +263,8 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)

network, address := c.config.networkAddress()
if c.config.Dial == nil {
c.config.Dial = (&net.Dialer{Timeout: c.config.Timeout, KeepAlive: 5 * time.Minute}).Dial
d := defaultDialer()
c.config.Dial = d.Dial
}

if c.shouldLog(LogLevelInfo) {
Expand Down Expand Up @@ -692,7 +696,9 @@ func ParseURI(uri string) (ConnConfig, error) {
if err != nil {
return cp, err
}
cp.Timeout = time.Duration(timeout) * time.Second
d := defaultDialer()
d.Timeout = time.Duration(timeout) * time.Second
cp.Dial = d.Dial
}

err = configSSL(url.Query().Get("sslmode"), &cp)
Expand Down Expand Up @@ -761,11 +767,13 @@ func ParseDSN(s string) (ConnConfig, error) {
case "sslmode":
sslmode = b[2]
case "connect_timeout":
t, err := strconv.ParseInt(b[2], 10, 64)
timeout, err := strconv.ParseInt(b[2], 10, 64)
if err != nil {
return cp, err
}
cp.Timeout = time.Duration(t) * time.Second
d := defaultDialer()
d.Timeout = time.Duration(timeout) * time.Second
cp.Dial = d.Dial
default:
cp.RuntimeParams[b[1]] = b[2]
}
Expand Down Expand Up @@ -841,7 +849,9 @@ func ParseEnvLibpq() (ConnConfig, error) {

if pgtimeout := os.Getenv("PGCONNECT_TIMEOUT"); pgtimeout != "" {
if timeout, err := strconv.ParseInt(pgtimeout, 10, 64); err == nil {
cc.Timeout = time.Duration(timeout) * time.Second
d := defaultDialer()
d.Timeout = time.Duration(timeout) * time.Second
cc.Dial = d.Dial
} else {
return cc, err
}
Expand Down
144 changes: 74 additions & 70 deletions conn_test.go
Expand Up @@ -576,7 +576,7 @@ func TestParseDSN(t *testing.T) {
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
Timeout: 10 * time.Second,
Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial,
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
Expand All @@ -585,15 +585,13 @@ func TestParseDSN(t *testing.T) {
}

for i, tt := range tests {
connParams, err := pgx.ParseDSN(tt.url)
actual, err := pgx.ParseDSN(tt.url)
if err != nil {
t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err)
continue
}

if !reflect.DeepEqual(connParams, tt.connParams) {
t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
}
testConnConfigEquals(t, tt.connParams, actual, strconv.Itoa(i))
}
}

Expand Down Expand Up @@ -721,7 +719,7 @@ func TestParseConnectionString(t *testing.T) {
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
Timeout: 10 * time.Second,
Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial,
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
Expand Down Expand Up @@ -819,16 +817,80 @@ func TestParseConnectionString(t *testing.T) {
}

for i, tt := range tests {
connParams, err := pgx.ParseConnectionString(tt.url)
actual, err := pgx.ParseConnectionString(tt.url)
if err != nil {
t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err)
continue
}

if !reflect.DeepEqual(connParams, tt.connParams) {
t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
testConnConfigEquals(t, tt.connParams, actual, strconv.Itoa(i))
}
}

func testConnConfigEquals(t *testing.T, expected pgx.ConnConfig, actual pgx.ConnConfig, testName string) {
if actual.Host != expected.Host {
t.Errorf("%s: expected Host to be %v got %v", testName, expected.Host, actual.Host)
}
if actual.Port != expected.Port {
t.Errorf("%s: expected Port to be %v got %v", testName, expected.Port, actual.Port)
}
if actual.Port != expected.Port {
t.Errorf("%s: expected Port to be %v got %v", testName, expected.Port, actual.Port)
}
if actual.User != expected.User {
t.Errorf("%s: expected User to be %v got %v", testName, expected.User, actual.User)
}
if actual.Password != expected.Password {
t.Errorf("%s: expected Password to be %v got %v", testName, expected.Password, actual.Password)
}
// Cannot test value of underlying Dialer stuct but can at least test if Dial func is set.
if (actual.Dial != nil) != (expected.Dial != nil) {
t.Errorf("%s: expected Dial mismatch", testName)
}

if !reflect.DeepEqual(actual.RuntimeParams, expected.RuntimeParams) {
t.Errorf("%s: expected RuntimeParams to be %#v got %#v", testName, expected.RuntimeParams, actual.RuntimeParams)
}

tlsTests := []struct {
name string
expected *tls.Config
actual *tls.Config
}{
{
name: "TLSConfig",
expected: expected.TLSConfig,
actual: actual.TLSConfig,
},
{
name: "FallbackTLSConfig",
expected: expected.FallbackTLSConfig,
actual: actual.FallbackTLSConfig,
},
}
for _, tlsTest := range tlsTests {
name := tlsTest.name
expected := tlsTest.expected
actual := tlsTest.actual

if expected == nil && actual != nil {
t.Errorf("%s / %s: expected nil, but it was set", testName, name)
} else if expected != nil && actual == nil {
t.Errorf("%s / %s: expected to be set, but got nil", testName, name)
} else if expected != nil && actual != nil {
if actual.InsecureSkipVerify != expected.InsecureSkipVerify {
t.Errorf("%s / %s: expected InsecureSkipVerify to be %v got %v", testName, name, expected.InsecureSkipVerify, actual.InsecureSkipVerify)
}

if actual.ServerName != expected.ServerName {
t.Errorf("%s / %s: expected ServerName to be %v got %v", testName, name, expected.ServerName, actual.ServerName)
}
}
}

if actual.UseFallbackTLS != expected.UseFallbackTLS {
t.Errorf("%s: expected UseFallbackTLS to be %v got %v", testName, expected.UseFallbackTLS, actual.UseFallbackTLS)
}
}

func TestParseEnvLibpq(t *testing.T) {
Expand Down Expand Up @@ -881,7 +943,7 @@ func TestParseEnvLibpq(t *testing.T) {
TLSConfig: &tls.Config{InsecureSkipVerify: true},
UseFallbackTLS: true,
FallbackTLSConfig: nil,
Timeout: 10 * time.Second,
Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial,
RuntimeParams: map[string]string{},
},
},
Expand Down Expand Up @@ -997,71 +1059,13 @@ func TestParseEnvLibpq(t *testing.T) {
}
}

config, err := pgx.ParseEnvLibpq()
actual, err := pgx.ParseEnvLibpq()
if err != nil {
t.Errorf("%s: Unexpected error from pgx.ParseLibpq() => %v", tt.name, err)
continue
}

if config.Host != tt.config.Host {
t.Errorf("%s: expected Host to be %v got %v", tt.name, tt.config.Host, config.Host)
}
if config.Port != tt.config.Port {
t.Errorf("%s: expected Port to be %v got %v", tt.name, tt.config.Port, config.Port)
}
if config.Port != tt.config.Port {
t.Errorf("%s: expected Port to be %v got %v", tt.name, tt.config.Port, config.Port)
}
if config.User != tt.config.User {
t.Errorf("%s: expected User to be %v got %v", tt.name, tt.config.User, config.User)
}
if config.Password != tt.config.Password {
t.Errorf("%s: expected Password to be %v got %v", tt.name, tt.config.Password, config.Password)
}

if !reflect.DeepEqual(config.RuntimeParams, tt.config.RuntimeParams) {
t.Errorf("%s: expected RuntimeParams to be %#v got %#v", tt.name, tt.config.RuntimeParams, config.RuntimeParams)
}

tlsTests := []struct {
name string
expected *tls.Config
actual *tls.Config
}{
{
name: "TLSConfig",
expected: tt.config.TLSConfig,
actual: config.TLSConfig,
},
{
name: "FallbackTLSConfig",
expected: tt.config.FallbackTLSConfig,
actual: config.FallbackTLSConfig,
},
}
for _, tlsTest := range tlsTests {
name := tlsTest.name
expected := tlsTest.expected
actual := tlsTest.actual

if expected == nil && actual != nil {
t.Errorf("%s / %s: expected nil, but it was set", tt.name, name)
} else if expected != nil && actual == nil {
t.Errorf("%s / %s: expected to be set, but got nil", tt.name, name)
} else if expected != nil && actual != nil {
if actual.InsecureSkipVerify != expected.InsecureSkipVerify {
t.Errorf("%s / %s: expected InsecureSkipVerify to be %v got %v", tt.name, name, expected.InsecureSkipVerify, actual.InsecureSkipVerify)
}

if actual.ServerName != expected.ServerName {
t.Errorf("%s / %s: expected ServerName to be %v got %v", tt.name, name, expected.ServerName, actual.ServerName)
}
}
}

if config.UseFallbackTLS != tt.config.UseFallbackTLS {
t.Errorf("%s: expected UseFallbackTLS to be %v got %v", tt.name, tt.config.UseFallbackTLS, config.UseFallbackTLS)
}
testConnConfigEquals(t, tt.config, actual, tt.name)
}
}

Expand Down

0 comments on commit 2c07b03

Please sign in to comment.