Permalink
Browse files

Implement many libpq environment variables

Use these to make the tests less dependent fewer compiled-in defaults
that cannot be overridden.

By default, the database "pqgotest" is used for tests now.  Unlike
libpq, "localhost" is used by default instead of the socket directory,
because there is a large split between vanilla PostgreSQL (which
places things in /tmp) and the way most packaged PostgreSQL places the
unix socket (/var/run/postgresql).  Also unlike libpq, we do not have the
advantage of the default configuration on the system being burned into
the driver, so "localhost" seems like a reasonable compromise.

A way to overcome that might be to call out to pg_config or link
against libpq, but that is not very good from a dependency perspective
and defeats the point of implementing a driver.

To run tests, for example, one can now write:

$ PGHOST=/var/run/postgresql/ go test pq

Signed-off-by: Dan Farina <drfarina@acm.org>
  • Loading branch information...
1 parent 4d8662f commit 6c7918fdcbd1dacbf6104ab2a6760539c9c45a42 Dan Farina committed Apr 6, 2012
Showing with 138 additions and 36 deletions.
  1. +96 −1 conn.go
  2. +42 −35 conn_test.go
View
97 conn.go
@@ -10,6 +10,9 @@ import (
"fmt"
"io"
"net"
+ "os"
+ "os/user"
+ "path"
"strconv"
"strings"
)
@@ -38,8 +41,29 @@ func Open(name string) (_ driver.Conn, err error) {
defer errRecover(&err)
o := make(Values)
+
+ // A number of defaults are applied here, in this order:
+ //
+ // * Very low precedence defaults applied in every situation
+ // * Environment variables
+ // * Explicitly passed connection information
o.Set("host", "localhost")
o.Set("port", "5432")
+
+ // Default the username, but ignore errors, because a user
+ // passed in via environment variable or connection string
+ // would be okay. This can result in connections failing
+ // *sometimes* if the client relies on being able to determine
+ // the current username and there are intermittent problems.
+ u, err := user.Current()
+ if err == nil {
+ o.Set("user", u.Username)
+ }
+
+ for k, v := range parseEnviron(os.Environ()) {
+ o.Set(k, v)
+ }
+
parseOpts(name, o)
c, err := net.Dial(network(o))
@@ -57,7 +81,8 @@ func network(o Values) (string, string) {
host := o.Get("host")
if strings.HasPrefix(host, "/") {
- return "unix", host
+ sockPath := path.Join(host, ".s.PGSQL."+o.Get("port"))
+ return "unix", sockPath
}
return "tcp", host + ":" + o.Get("port")
@@ -505,3 +530,73 @@ func md5s(s string) string {
h.Write([]byte(s))
return fmt.Sprintf("%x", h.Sum(nil))
}
+
+// parseEnviron tries to mimic some of libpq's environment handling
+//
+// To ease testing, it does not directly reference os.Environ, but is
+// designed to accept its output.
+//
+// Environment-set connection information is intended to have a higher
+// precedence than a library default but lower than any explicitly
+// passed information (such as in the URL or connection string).
+func parseEnviron(env []string) (out map[string]string) {
+ out = make(map[string]string)
+
+ for _, v := range env {
+ parts := strings.SplitN(v, "=", 2)
+
+ accrue := func(keyname string) {
+ out[keyname] = parts[1]
+ }
+
+ // The order of these is the same as is seen in the
+ // PostgreSQL 9.1 manual, with omissions briefly
+ // noted.
+ switch parts[0] {
+ case "PGHOST":
+ accrue("host")
+ case "PGHOSTADDR":
+ accrue("hostaddr")
+ case "PGPORT":
+ accrue("port")
+ case "PGDATABASE":
+ accrue("dbname")
+ case "PGUSER":
+ accrue("user")
+ case "PGPASSWORD":
+ accrue("password")
+ // skip PGPASSFILE, PGSERVICE, PGSERVICEFILE,
+ // PGREALM
+ case "PGOPTIONS":
+ accrue("options")
+ case "PGAPPNAME":
+ accrue("application_name")
+ case "PGSSLMODE":
+ accrue("sslmode")
+ case "PGREQUIRESSL":
+ accrue("requiressl")
+ case "PGSSLCERT":
+ accrue("sslcert")
+ case "PGSSLKEY":
+ accrue("sslkey")
+ case "PGSSLROOTCERT":
+ accrue("sslrootcert")
+ case "PGSSLCRL":
+ accrue("sslcrl")
+ case "PGREQUIREPEER":
+ accrue("requirepeer")
+ case "PGKRBSRVNAME":
+ accrue("krbsrvname")
+ case "PGGSSLIB":
+ accrue("gsslib")
+ case "PGCONNECT_TIMEOUT":
+ accrue("connect_timeout")
+ case "PGCLIENTENCODING":
+ accrue("client_encoding")
+ // skip PGDATESTYLE, PGTZ, PGGEQO, PGSYSCONFDIR,
+ // PGLOCALEDIR
+ }
+ }
+
+ return out
+}
View
@@ -4,21 +4,37 @@ import (
"database/sql"
"database/sql/driver"
"io"
+ "os"
"reflect"
"testing"
"time"
)
-var cs = "user=pqgotest sslmode=disable"
+func openTestConn(t *testing.T) *sql.DB {
+ datname := os.Getenv("PGDATABASE")
+ sslmode := os.Getenv("PGSSLMODE")
-func TestExec(t *testing.T) {
- db, err := sql.Open("postgres", cs)
+ if datname == "" {
+ os.Setenv("PGDATABASE", "pqgotest")
+ }
+
+ if sslmode == "" {
+ os.Setenv("PGSSLMODE", "disable")
+ }
+
+ conn, err := sql.Open("postgres", "")
if err != nil {
t.Fatal(err)
}
+
+ return conn
+}
+
+func TestExec(t *testing.T) {
+ db := openTestConn(t)
defer db.Close()
- _, err = db.Exec("CREATE TEMP TABLE temp (a int)")
+ _, err := db.Exec("CREATE TEMP TABLE temp (a int)")
if err != nil {
t.Fatal(err)
}
@@ -34,10 +50,7 @@ func TestExec(t *testing.T) {
}
func TestStatment(t *testing.T) {
- db, err := sql.Open("postgres", cs)
- if err != nil {
- t.Fatal(err)
- }
+ db := openTestConn(t)
defer db.Close()
st, err := db.Prepare("SELECT 1")
@@ -94,10 +107,8 @@ func TestStatment(t *testing.T) {
}
func TestRowsCloseBeforeDone(t *testing.T) {
- db, err := sql.Open("postgres", cs)
- if err != nil {
- t.Fatal(err)
- }
+ db := openTestConn(t)
+ defer db.Close()
r, err := db.Query("SELECT 1")
if err != nil {
@@ -119,10 +130,7 @@ func TestRowsCloseBeforeDone(t *testing.T) {
}
func TestEncodeDecode(t *testing.T) {
- db, err := sql.Open("postgres", cs)
- if err != nil {
- t.Fatal(err)
- }
+ db := openTestConn(t)
defer db.Close()
q := `
@@ -182,10 +190,7 @@ func TestEncodeDecode(t *testing.T) {
}
func TestNoData(t *testing.T) {
- db, err := sql.Open("postgres", cs)
- if err != nil {
- t.Fatal(err)
- }
+ db := openTestConn(t)
defer db.Close()
st, err := db.Prepare("SELECT 1 WHERE true = false")
@@ -209,7 +214,9 @@ func TestNoData(t *testing.T) {
}
func TestPGError(t *testing.T) {
- db, err := sql.Open("postgres", "user=asdf")
+ // Don't use the normal connection setup, this is intended to
+ // blow up in the startup packet from a non-existent user.
+ db, err := sql.Open("postgres", "user=thisuserreallydoesntexist")
if err != nil {
t.Fatal(err)
}
@@ -250,14 +257,11 @@ func TestBadConn(t *testing.T) {
}
func TestErrorOnExec(t *testing.T) {
- db, err := sql.Open("postgres", cs)
- if err != nil {
- t.Fatal(err)
- }
+ db := openTestConn(t)
defer db.Close()
sql := "DO $$BEGIN RAISE unique_violation USING MESSAGE='foo'; END; $$;"
- _, err = db.Exec(sql)
+ _, err := db.Exec(sql)
_, ok := err.(*PGError)
if !ok {
t.Fatalf("expected PGError, was: %#v", err)
@@ -270,10 +274,7 @@ func TestErrorOnExec(t *testing.T) {
}
func TestErrorOnQuery(t *testing.T) {
- db, err := sql.Open("postgres", cs)
- if err != nil {
- t.Fatal(err)
- }
+ db := openTestConn(t)
defer db.Close()
sql := "DO $$BEGIN RAISE unique_violation USING MESSAGE='foo'; END; $$;"
@@ -302,12 +303,10 @@ func TestErrorOnQuery(t *testing.T) {
}
func TestBindError(t *testing.T) {
- db, err := sql.Open("postgres", cs)
- if err != nil {
- t.Fatal(err)
- }
+ db := openTestConn(t)
+ defer db.Close()
- _, err = db.Exec("create temp table test (i integer)")
+ _, err := db.Exec("create temp table test (i integer)")
if err != nil {
t.Fatal(err)
}
@@ -323,3 +322,11 @@ func TestBindError(t *testing.T) {
t.Fatal(err)
}
}
+
+func TestParseEnviron(t *testing.T) {
+ expected := map[string]string{"dbname": "hello", "user": "goodbye"}
+ results := parseEnviron([]string{"PGDATABASE=hello", "PGUSER=goodbye"})
+ if !reflect.DeepEqual(expected, results) {
+ t.Fatalf("Expected: %#v Got: %#v", expected, results)
+ }
+}

0 comments on commit 6c7918f

Please sign in to comment.