diff --git a/pgutils/connector.go b/pgutils/connector.go index f859c0b..0fa35a5 100644 --- a/pgutils/connector.go +++ b/pgutils/connector.go @@ -127,23 +127,23 @@ func MustConnectDB(conn driver.Connector) *sqlx.DB { return db } -// addSearchPathToURL returns a copy of u with search_path set in the query string. -// It returns an error if search_path is already present. +// addSearchPathToURL returns a copy of u with search_path set in the options parameter +// of the query string. It returns an error if the search_path or options parameter is +// already present. func addSearchPathToURL(rawURL string, searchPath string) (string, error) { u, err := url.Parse(rawURL) if err != nil { return "", fmt.Errorf("url string failed to parse while adding search path: %w", err) } - if searchPath == "" { - return u.String(), nil - } - q := u.Query() - if v := q.Get("search_path"); v != "" { + if v, ok := q["search_path"]; ok { return "", fmt.Errorf("search_path already set to %q", v) } - q.Set("search_path", searchPath) + if v, ok := q["options"]; ok { + return "", fmt.Errorf("options already set to %q", v) + } + q.Set("options", fmt.Sprintf("-csearch_path=%s", searchPath)) u.RawQuery = q.Encode() return u.String(), nil } @@ -247,6 +247,7 @@ func newIAMConnectionStringProviderFromURL(ctx context.Context, u *url.URL, onTo supportedParams := map[string]struct{}{ "assume_role_arn": {}, "assume_role_session_name": {}, + "search_path": {}, } for k := range q { if _, ok := supportedParams[k]; !ok { @@ -278,7 +279,7 @@ func newIAMConnectionStringProviderFromURL(ctx context.Context, u *url.URL, onTo creds = aws.NewCredentialsCache(assumeProvider) } - return &rdsIAMConnectionStringProvider{ + var p ConnectionStringProvider = &rdsIAMConnectionStringProvider{ Region: awsCfg.Region, RDSEndpoint: net.JoinHostPort(host, port), User: user, @@ -287,5 +288,14 @@ func newIAMConnectionStringProviderFromURL(ctx context.Context, u *url.URL, onTo AssumeRoleARN: assumeRoleARN, AssumeRoleSessionName: sessionName, OnTokenSign: onTokenSign, - }, nil + } + + if searchPath, ok := q["search_path"]; ok { + if len(searchPath) > 1 { + return nil, fmt.Errorf("Multiple search_path values specified") + } + p = WithSchemaSearchPath(p, searchPath[0]) + } + + return p, nil }