Skip to content

Commit

Permalink
Make pgconn.ConnectError and pgconn.ParseConfigError public
Browse files Browse the repository at this point in the history
fixes #1773
  • Loading branch information
jackc committed Jan 12, 2024
1 parent 44768b5 commit 5d26bbe
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 35 deletions.
14 changes: 7 additions & 7 deletions pgconn/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,12 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
connStringSettings, err = parseURLSettings(connString)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err}
}
} else {
connStringSettings, err = parseDSNSettings(connString)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as DSN", err: err}
}
}
}
Expand All @@ -251,7 +251,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
if service, present := settings["service"]; present {
serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err}
return nil, &ParseConfigError{ConnString: connString, msg: "failed to read service", err: err}
}

settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
Expand All @@ -278,7 +278,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
return nil, &ParseConfigError{ConnString: connString, msg: "invalid connect_timeout", err: err}
}
config.ConnectTimeout = connectTimeout
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
Expand Down Expand Up @@ -340,7 +340,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con

port, err := parsePort(portStr)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
return nil, &ParseConfigError{ConnString: connString, msg: "invalid port", err: err}
}

var tlsConfigs []*tls.Config
Expand All @@ -352,7 +352,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
var err error
tlsConfigs, err = configTLS(settings, host, options)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
return nil, &ParseConfigError{ConnString: connString, msg: "failed to configure TLS", err: err}
}
}

Expand Down Expand Up @@ -396,7 +396,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
case "any":
// do nothing
default:
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
}

return config, nil
Expand Down
22 changes: 12 additions & 10 deletions pgconn/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,23 @@ func (pe *PgError) SQLState() string {
return pe.Code
}

type connectError struct {
config *Config
// ConnectError is the error returned when a connection attempt fails.
type ConnectError struct {
Config *Config // The configuration that was used in the connection attempt.
msg string
err error
}

func (e *connectError) Error() string {
func (e *ConnectError) Error() string {
sb := &strings.Builder{}
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.Config.Host, e.Config.User, e.Config.Database, e.msg)
if e.err != nil {
fmt.Fprintf(sb, " (%s)", e.err.Error())
}
return sb.String()
}

func (e *connectError) Unwrap() error {
func (e *ConnectError) Unwrap() error {
return e.err
}

Expand All @@ -88,21 +89,22 @@ func (e *connLockError) Error() string {
return e.status
}

type parseConfigError struct {
connString string
// ParseConfigError is the error returned when a connection string cannot be parsed.
type ParseConfigError struct {
ConnString string // The connection string that could not be parsed.
msg string
err error
}

func (e *parseConfigError) Error() string {
connString := redactPW(e.connString)
func (e *ParseConfigError) Error() string {
connString := redactPW(e.ConnString)
if e.err == nil {
return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg)
}
return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error())
}

func (e *parseConfigError) Unwrap() error {
func (e *ParseConfigError) Unwrap() error {
return e.err
}

Expand Down
4 changes: 2 additions & 2 deletions pgconn/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
package pgconn

func NewParseConfigError(conn, msg string, err error) error {
return &parseConfigError{
connString: conn,
return &ParseConfigError{
ConnString: conn,
msg: msg,
err: err,
}
Expand Down
32 changes: 16 additions & 16 deletions pgconn/pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,11 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
ctx := octx
fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs)
if err != nil {
return nil, &connectError{config: config, msg: "hostname resolving error", err: err}
return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: err}
}

if len(fallbackConfigs) == 0 {
return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
}

foundBestServer := false
Expand All @@ -178,7 +178,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
foundBestServer = true
break
} else if pgerr, ok := err.(*PgError); ok {
err = &connectError{config: config, msg: "server error", err: pgerr}
err = &ConnectError{Config: config, msg: "server error", err: pgerr}
const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password
const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings
const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
Expand All @@ -189,7 +189,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
break
}
} else if cerr, ok := err.(*connectError); ok {
} else if cerr, ok := err.(*ConnectError); ok {
if _, ok := cerr.err.(*NotPreferredError); ok {
fallbackConfig = fc
}
Expand All @@ -199,7 +199,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
if !foundBestServer && fallbackConfig != nil {
pgConn, err = connect(ctx, config, fallbackConfig, true)
if pgerr, ok := err.(*PgError); ok {
err = &connectError{config: config, msg: "server error", err: pgerr}
err = &ConnectError{Config: config, msg: "server error", err: pgerr}
}
}

Expand All @@ -211,7 +211,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
err := config.AfterConnect(ctx, pgConn)
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "AfterConnect error", err: err}
return nil, &ConnectError{Config: config, msg: "AfterConnect error", err: err}
}
}

Expand Down Expand Up @@ -283,7 +283,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
netConn, err := config.DialFunc(ctx, network, address)
if err != nil {
return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
}

pgConn.conn = netConn
Expand All @@ -295,7 +295,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
if err != nil {
netConn.Close()
return nil, &connectError{config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)}
return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)}
}

pgConn.conn = nbTLSConn
Expand Down Expand Up @@ -336,7 +336,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.frontend.Send(&startupMsg)
if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
return nil, &ConnectError{Config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
}

for {
Expand All @@ -346,7 +346,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if err, ok := err.(*PgError); ok {
return nil, err
}
return nil, &connectError{config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)}
return nil, &ConnectError{Config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)}
}

switch msg := msg.(type) {
Expand All @@ -359,26 +359,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
err = pgConn.txPasswordMessage(pgConn.config.Password)
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write password message", err: err}
return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
}
case *pgproto3.AuthenticationMD5Password:
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
err = pgConn.txPasswordMessage(digestedPassword)
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write password message", err: err}
return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
}
case *pgproto3.AuthenticationSASL:
err = pgConn.scramAuth(msg.AuthMechanisms)
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed SASL auth", err: err}
return nil, &ConnectError{Config: config, msg: "failed SASL auth", err: err}
}
case *pgproto3.AuthenticationGSS:
err = pgConn.gssAuth()
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed GSS auth", err: err}
return nil, &ConnectError{Config: config, msg: "failed GSS auth", err: err}
}
case *pgproto3.ReadyForQuery:
pgConn.status = connStatusIdle
Expand All @@ -396,7 +396,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return pgConn, nil
}
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err}
return nil, &ConnectError{Config: config, msg: "ValidateConnect failed", err: err}
}
}
return pgConn, nil
Expand All @@ -407,7 +407,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return nil, ErrorResponseToPgError(msg)
default:
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "received unexpected message", err: err}
return nil, &ConnectError{Config: config, msg: "received unexpected message", err: err}
}
}
}
Expand Down

0 comments on commit 5d26bbe

Please sign in to comment.