-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(connection): add functions to validate the connection string
- Loading branch information
1 parent
77c2705
commit e5f8b5d
Showing
2 changed files
with
289 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
package connection | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"net/url" | ||
"strings" | ||
|
||
"github.com/danvergara/dblab/pkg/command" | ||
) | ||
|
||
var ( | ||
// ErrInvalidUPostgresRLFormat is the error used to notify that the postgres given url is not valid. | ||
ErrInvalidUPostgresRLFormat = errors.New("Invalid URL. Valid format: postgres://user:password@host:port/db?sslmode=mode") | ||
// ErrInvalidUMySQLRLFormat is the error used to notify that the given mysql url is not valid. | ||
ErrInvalidUMySQLRLFormat = errors.New("Invalid URL. Valid format: mysql://user:password@tcp(host:port)/db") | ||
// ErrInvalidURLFormat is used to notify the url is invalid. | ||
ErrInvalidURLFormat = errors.New("invalid url") | ||
) | ||
|
||
// BuildConnectionFromOpts return the connection uri string given the options passed by the uses. | ||
func BuildConnectionFromOpts(opts command.Options) (string, error) { | ||
if opts.URL != "" { | ||
switch opts.Driver { | ||
case "postgres": | ||
return formatPostgresURL(opts) | ||
case "mysql": | ||
return formatMySQLURL(opts) | ||
default: | ||
return "", fmt.Errorf("%s: %w", opts.URL, ErrInvalidURLFormat) | ||
} | ||
} | ||
|
||
return "", nil | ||
} | ||
|
||
// formatPostgresURL returns valid uri for postgres connection. | ||
func formatPostgresURL(opts command.Options) (string, error) { | ||
if !hasValidPosgresPrefix(opts.URL) { | ||
return "", fmt.Errorf("invalid prefix %s : %w", opts.URL, ErrInvalidUPostgresRLFormat) | ||
} | ||
|
||
uri, err := url.Parse(opts.URL) | ||
if err != nil { | ||
return "", fmt.Errorf("%v : %w", err, ErrInvalidUPostgresRLFormat) | ||
} | ||
|
||
result := map[string]string{} | ||
for k, v := range uri.Query() { | ||
result[strings.ToLower(k)] = v[0] | ||
} | ||
|
||
if result["sslmode"] == "" { | ||
if opts.SSL == "" { | ||
if strings.Contains(uri.Host, "localhost") || strings.Contains(uri.Host, "127.0.0.1") { | ||
result["sslmode"] = "disable" | ||
} | ||
} else { | ||
result["sslmode"] = opts.SSL | ||
} | ||
} | ||
|
||
query := url.Values{} | ||
for k, v := range result { | ||
query.Add(k, v) | ||
} | ||
uri.RawQuery = query.Encode() | ||
|
||
return uri.String(), nil | ||
} | ||
|
||
// formatMySQLURL returns valid uri for mysql connection. | ||
func formatMySQLURL(opts command.Options) (string, error) { | ||
if !hasValidMySQLPrefix(opts.URL) { | ||
return "", ErrInvalidUMySQLRLFormat | ||
} | ||
|
||
uri, err := url.Parse(opts.URL) | ||
if err != nil { | ||
return "", ErrInvalidUMySQLRLFormat | ||
} | ||
|
||
result := map[string]string{} | ||
for k, v := range uri.Query() { | ||
result[strings.ToLower(k)] = v[0] | ||
} | ||
|
||
query := url.Values{} | ||
for k, v := range result { | ||
query.Add(k, v) | ||
} | ||
uri.RawQuery = query.Encode() | ||
|
||
return uri.String(), nil | ||
} | ||
|
||
// hasValidPosgresPrefix checks if a given url has the driver name in it. | ||
func hasValidPosgresPrefix(rawurl string) bool { | ||
return strings.HasPrefix(rawurl, "postgres://") || strings.HasPrefix(rawurl, "postgresql://") | ||
} | ||
|
||
// hasValidMySQLPrefix checks if a given url has the driver name in it. | ||
func hasValidMySQLPrefix(rawurl string) bool { | ||
return strings.HasPrefix(rawurl, "mysql://") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
package connection | ||
|
||
import ( | ||
"errors" | ||
"testing" | ||
|
||
"github.com/danvergara/dblab/pkg/command" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestBuildConnectionFromOpts(t *testing.T) { | ||
type given struct { | ||
opts command.Options | ||
} | ||
type want struct { | ||
uri string | ||
hasError bool | ||
err error | ||
} | ||
var cases = []struct { | ||
name string | ||
given given | ||
want want | ||
}{ | ||
{ | ||
name: "valid postgres localhost", | ||
given: given{ | ||
opts: command.Options{ | ||
Driver: "postgres", | ||
URL: "postgres://user:password@localhost:5432/db?sslmode=disable", | ||
}, | ||
}, | ||
want: want{ | ||
uri: "postgres://user:password@localhost:5432/db?sslmode=disable", | ||
}, | ||
}, | ||
{ | ||
name: "valid postgres localhost but add sslmode", | ||
given: given{ | ||
opts: command.Options{ | ||
Driver: "postgres", | ||
URL: "postgres://user:password@localhost:5432/db", | ||
}, | ||
}, | ||
want: want{ | ||
uri: "postgres://user:password@localhost:5432/db?sslmode=disable", | ||
}, | ||
}, | ||
{ | ||
name: "valid postgres localhost postgresql as protocol", | ||
given: given{ | ||
opts: command.Options{ | ||
Driver: "postgres", | ||
URL: "postgresql://user:password@localhost:5432/db", | ||
}, | ||
}, | ||
want: want{ | ||
uri: "postgresql://user:password@localhost:5432/db?sslmode=disable", | ||
}, | ||
}, | ||
{ | ||
name: "error misspelled postgres", | ||
given: given{ | ||
opts: command.Options{ | ||
Driver: "postgres", | ||
URL: "potgre://user:password@localhost:5432/db", | ||
}, | ||
}, | ||
want: want{ | ||
hasError: true, | ||
err: ErrInvalidUPostgresRLFormat, | ||
}, | ||
}, | ||
{ | ||
name: "error invalid url", | ||
given: given{ | ||
opts: command.Options{ | ||
Driver: "postgres", | ||
URL: "not-a-url", | ||
}, | ||
}, | ||
want: want{ | ||
hasError: true, | ||
err: ErrInvalidUPostgresRLFormat, | ||
}, | ||
}, | ||
} | ||
|
||
for _, tc := range cases { | ||
tc := tc | ||
t.Run(tc.name, func(t *testing.T) { | ||
t.Parallel() | ||
|
||
uri, err := BuildConnectionFromOpts(tc.given.opts) | ||
|
||
if tc.want.hasError { | ||
assert.Error(t, err) | ||
|
||
if !errors.Is(err, tc.want.err) { | ||
t.Errorf("got %v, wanted %v", err, tc.want.err) | ||
} | ||
|
||
return | ||
} | ||
|
||
assert.NoError(t, err) | ||
assert.Equal(t, tc.want.uri, uri) | ||
}) | ||
} | ||
} | ||
|
||
func TestFormatPostgresURL(t *testing.T) { | ||
type given struct { | ||
opts command.Options | ||
} | ||
type want struct { | ||
uri string | ||
hasError bool | ||
err error | ||
} | ||
var cases = []struct { | ||
name string | ||
given given | ||
want want | ||
}{} | ||
|
||
for _, tc := range cases { | ||
tc := tc | ||
t.Run(tc.name, func(t *testing.T) { | ||
t.Parallel() | ||
uri, err := formatPostgresURL(tc.given.opts) | ||
|
||
if tc.want.hasError { | ||
assert.Error(t, err) | ||
|
||
if !errors.Is(err, tc.want.err) { | ||
t.Errorf("got %v, wanted %v", err, tc.want.err) | ||
} | ||
|
||
return | ||
} | ||
|
||
assert.NoError(t, err) | ||
assert.Equal(t, tc.want.uri, uri) | ||
}) | ||
} | ||
} | ||
|
||
func TestFormatMySQLURL(t *testing.T) { | ||
type given struct { | ||
opts command.Options | ||
} | ||
type want struct { | ||
uri string | ||
hasError bool | ||
err error | ||
} | ||
var cases = []struct { | ||
name string | ||
given given | ||
want want | ||
}{} | ||
|
||
for _, tc := range cases { | ||
tc := tc | ||
t.Run(tc.name, func(t *testing.T) { | ||
t.Parallel() | ||
uri, err := formatMySQLURL(tc.given.opts) | ||
|
||
if tc.want.hasError { | ||
assert.Error(t, err) | ||
|
||
if !errors.Is(err, tc.want.err) { | ||
t.Errorf("got %v, wanted %v", err, tc.want.err) | ||
} | ||
|
||
return | ||
} | ||
|
||
assert.NoError(t, err) | ||
assert.Equal(t, tc.want.uri, uri) | ||
}) | ||
} | ||
} |