Skip to content

Commit

Permalink
feat(connection): add functions to validate the connection string
Browse files Browse the repository at this point in the history
  • Loading branch information
danvergara committed Apr 14, 2021
1 parent 77c2705 commit e5f8b5d
Show file tree
Hide file tree
Showing 2 changed files with 289 additions and 0 deletions.
105 changes: 105 additions & 0 deletions pkg/connection/connection.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package connection

import (
"errors"
"fmt"
"net/url"
"strings"

"github.com/danvergara/dblab/pkg/command"
)

var (
// ErrInvalidUPostgresRLFormat is the error used to notify that the postgres given url is not valid.
ErrInvalidUPostgresRLFormat = errors.New("Invalid URL. Valid format: postgres://user:password@host:port/db?sslmode=mode")
// ErrInvalidUMySQLRLFormat is the error used to notify that the given mysql url is not valid.
ErrInvalidUMySQLRLFormat = errors.New("Invalid URL. Valid format: mysql://user:password@tcp(host:port)/db")
// ErrInvalidURLFormat is used to notify the url is invalid.
ErrInvalidURLFormat = errors.New("invalid url")
)

// BuildConnectionFromOpts return the connection uri string given the options passed by the uses.
func BuildConnectionFromOpts(opts command.Options) (string, error) {
if opts.URL != "" {
switch opts.Driver {
case "postgres":
return formatPostgresURL(opts)
case "mysql":
return formatMySQLURL(opts)
default:
return "", fmt.Errorf("%s: %w", opts.URL, ErrInvalidURLFormat)
}
}

return "", nil
}

// formatPostgresURL returns valid uri for postgres connection.
func formatPostgresURL(opts command.Options) (string, error) {
if !hasValidPosgresPrefix(opts.URL) {
return "", fmt.Errorf("invalid prefix %s : %w", opts.URL, ErrInvalidUPostgresRLFormat)
}

uri, err := url.Parse(opts.URL)
if err != nil {
return "", fmt.Errorf("%v : %w", err, ErrInvalidUPostgresRLFormat)
}

result := map[string]string{}
for k, v := range uri.Query() {
result[strings.ToLower(k)] = v[0]
}

if result["sslmode"] == "" {
if opts.SSL == "" {
if strings.Contains(uri.Host, "localhost") || strings.Contains(uri.Host, "127.0.0.1") {
result["sslmode"] = "disable"
}
} else {
result["sslmode"] = opts.SSL
}
}

query := url.Values{}
for k, v := range result {
query.Add(k, v)
}
uri.RawQuery = query.Encode()

return uri.String(), nil
}

// formatMySQLURL returns valid uri for mysql connection.
func formatMySQLURL(opts command.Options) (string, error) {
if !hasValidMySQLPrefix(opts.URL) {
return "", ErrInvalidUMySQLRLFormat
}

uri, err := url.Parse(opts.URL)
if err != nil {
return "", ErrInvalidUMySQLRLFormat
}

result := map[string]string{}
for k, v := range uri.Query() {
result[strings.ToLower(k)] = v[0]
}

query := url.Values{}
for k, v := range result {
query.Add(k, v)
}
uri.RawQuery = query.Encode()

return uri.String(), nil
}

// hasValidPosgresPrefix checks if a given url has the driver name in it.
func hasValidPosgresPrefix(rawurl string) bool {
return strings.HasPrefix(rawurl, "postgres://") || strings.HasPrefix(rawurl, "postgresql://")
}

// hasValidMySQLPrefix checks if a given url has the driver name in it.
func hasValidMySQLPrefix(rawurl string) bool {
return strings.HasPrefix(rawurl, "mysql://")
}
184 changes: 184 additions & 0 deletions pkg/connection/connection_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
package connection

import (
"errors"
"testing"

"github.com/danvergara/dblab/pkg/command"
"github.com/stretchr/testify/assert"
)

func TestBuildConnectionFromOpts(t *testing.T) {
type given struct {
opts command.Options
}
type want struct {
uri string
hasError bool
err error
}
var cases = []struct {
name string
given given
want want
}{
{
name: "valid postgres localhost",
given: given{
opts: command.Options{
Driver: "postgres",
URL: "postgres://user:password@localhost:5432/db?sslmode=disable",
},
},
want: want{
uri: "postgres://user:password@localhost:5432/db?sslmode=disable",
},
},
{
name: "valid postgres localhost but add sslmode",
given: given{
opts: command.Options{
Driver: "postgres",
URL: "postgres://user:password@localhost:5432/db",
},
},
want: want{
uri: "postgres://user:password@localhost:5432/db?sslmode=disable",
},
},
{
name: "valid postgres localhost postgresql as protocol",
given: given{
opts: command.Options{
Driver: "postgres",
URL: "postgresql://user:password@localhost:5432/db",
},
},
want: want{
uri: "postgresql://user:password@localhost:5432/db?sslmode=disable",
},
},
{
name: "error misspelled postgres",
given: given{
opts: command.Options{
Driver: "postgres",
URL: "potgre://user:password@localhost:5432/db",
},
},
want: want{
hasError: true,
err: ErrInvalidUPostgresRLFormat,
},
},
{
name: "error invalid url",
given: given{
opts: command.Options{
Driver: "postgres",
URL: "not-a-url",
},
},
want: want{
hasError: true,
err: ErrInvalidUPostgresRLFormat,
},
},
}

for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

uri, err := BuildConnectionFromOpts(tc.given.opts)

if tc.want.hasError {
assert.Error(t, err)

if !errors.Is(err, tc.want.err) {
t.Errorf("got %v, wanted %v", err, tc.want.err)
}

return
}

assert.NoError(t, err)
assert.Equal(t, tc.want.uri, uri)
})
}
}

func TestFormatPostgresURL(t *testing.T) {
type given struct {
opts command.Options
}
type want struct {
uri string
hasError bool
err error
}
var cases = []struct {
name string
given given
want want
}{}

for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
uri, err := formatPostgresURL(tc.given.opts)

if tc.want.hasError {
assert.Error(t, err)

if !errors.Is(err, tc.want.err) {
t.Errorf("got %v, wanted %v", err, tc.want.err)
}

return
}

assert.NoError(t, err)
assert.Equal(t, tc.want.uri, uri)
})
}
}

func TestFormatMySQLURL(t *testing.T) {
type given struct {
opts command.Options
}
type want struct {
uri string
hasError bool
err error
}
var cases = []struct {
name string
given given
want want
}{}

for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
uri, err := formatMySQLURL(tc.given.opts)

if tc.want.hasError {
assert.Error(t, err)

if !errors.Is(err, tc.want.err) {
t.Errorf("got %v, wanted %v", err, tc.want.err)
}

return
}

assert.NoError(t, err)
assert.Equal(t, tc.want.uri, uri)
})
}
}

0 comments on commit e5f8b5d

Please sign in to comment.