diff --git a/postgres.go b/postgres.go index e1e6c65..ae1a2e3 100644 --- a/postgres.go +++ b/postgres.go @@ -34,7 +34,11 @@ func New(u *url.URL) (waitfor.Resource, error) { } func (s *Postgres) Test(ctx context.Context) error { - db, err := sql.Open(s.url.Scheme, strings.TrimPrefix(s.url.String(), Scheme+"://")) + // Always use "postgres" as the driver name for sql.Open, regardless of the URL scheme + // Remove the scheme and "://" from the URL to get the connection string + connStr := strings.TrimPrefix(s.url.String(), s.url.Scheme+"://") + + db, err := sql.Open(Scheme, connStr) if err != nil { return err diff --git a/postgres_test.go b/postgres_test.go index e19a595..13a66ac 100644 --- a/postgres_test.go +++ b/postgres_test.go @@ -5,6 +5,8 @@ import ( "github.com/go-waitfor/waitfor" "github.com/go-waitfor/waitfor-postgres" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "net/url" "testing" "time" ) @@ -19,3 +21,73 @@ func TestUse(t *testing.T) { assert.Error(t, err) } + +func TestNew(t *testing.T) { + t.Run("with valid URL", func(t *testing.T) { + u, err := url.Parse("postgres://localhost:5432/testdb") + require.NoError(t, err) + + resource, err := postgres.New(u) + assert.NoError(t, err) + assert.NotNil(t, resource) + }) + + t.Run("with nil URL", func(t *testing.T) { + resource, err := postgres.New(nil) + assert.Error(t, err) + assert.Nil(t, resource) + assert.Contains(t, err.Error(), "url") + assert.Contains(t, err.Error(), "invalid argument") + }) +} + +func TestPostgres_Test(t *testing.T) { + t.Run("with invalid postgres URL", func(t *testing.T) { + u, err := url.Parse("postgres://nonexistent:5432/testdb") + require.NoError(t, err) + + resource, err := postgres.New(u) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + err = resource.Test(ctx) + assert.Error(t, err) + }) + + t.Run("with URL that causes connection failure", func(t *testing.T) { + // Use a valid URL format but with invalid host + u, err := url.Parse("postgres://user:password@invalid-host:5432/database") + require.NoError(t, err) + + resource, err := postgres.New(u) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + err = resource.Test(ctx) + assert.Error(t, err) + }) + + t.Run("with postgresql scheme (alternate scheme bug test)", func(t *testing.T) { + // This tests a potential bug where using 'postgresql://' scheme + // instead of 'postgres://' could cause sql.Open to fail + u, err := url.Parse("postgresql://user:password@localhost:5432/database") + require.NoError(t, err) + + resource, err := postgres.New(u) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + // This should work (or at least fail with connection error, not driver error) + err = resource.Test(ctx) + // We expect a connection error, not a driver registration error + assert.Error(t, err) + // The error should not be about unknown driver + assert.NotContains(t, err.Error(), "unknown driver") + }) +}