From d461eb19031ee2a88c896d372bd110c30f901b93 Mon Sep 17 00:00:00 2001 From: Andrej Kenda Date: Sat, 8 Jul 2023 19:23:56 +0200 Subject: [PATCH] go/clients/mongo,mysql: use url.Parse instead of regex --- dbee/clients/mongo.go | 49 +++++++++---------------------------------- dbee/clients/mysql.go | 18 ++++++++-------- 2 files changed, 19 insertions(+), 48 deletions(-) diff --git a/dbee/clients/mongo.go b/dbee/clients/mongo.go index c149ab2..767388b 100644 --- a/dbee/clients/mongo.go +++ b/dbee/clients/mongo.go @@ -6,7 +6,8 @@ import ( "encoding/json" "errors" "fmt" - "regexp" + "net/url" + "strings" "time" "github.com/kndndrj/nvim-dbee/dbee/clients/common" @@ -44,40 +45,6 @@ func init() { // gob.Register(primitive.Undefined{}) gob.Register(primitive.DBPointer{}) // gob.Register(primitive.Symbol) - -} - -func getDatabaseName(url string) (string, error) { - r, err := regexp.Compile(`mongo.*//(.*:[0-9]+,?)+/(?P.*?)(\?|$)`) - if err != nil { - return "", err - } - - // get submatch index - getSubmatchIndex := func(submatch []string, name string) (int, error) { - for i, n := range submatch { - if n == name { - return i, nil - } - } - return 0, errors.New("no submatch found") - } - i, err := getSubmatchIndex(r.SubexpNames(), "dbname") - if err != nil { - return "", err - } - - // get database name from capture group (with index) - submatch := r.FindStringSubmatch(url) - if len(submatch) < 1 { - return "", errors.New("url doesn't comply to schema") - } - dbName := submatch[i] - if dbName == "" { - return "", errors.New("no dbname found") - } - - return dbName, nil } type MongoClient struct { @@ -85,14 +52,18 @@ type MongoClient struct { dbName string } -func NewMongo(url string) (*MongoClient, error) { +func NewMongo(rawURL string) (*MongoClient, error) { // get database name from url - dbName, err := getDatabaseName(url) + u, err := url.Parse(rawURL) if err != nil { - return nil, fmt.Errorf("mongo: invalid url: %v", err) + return nil, fmt.Errorf("mongo: invalid url: %w", err) + } + dbName := strings.TrimPrefix(u.Path, "/") + if dbName == "" { + return nil, fmt.Errorf("mongo: url doesn't comply to schema: database name must be set") } - opts := options.Client().ApplyURI(url) + opts := options.Client().ApplyURI(rawURL) client, err := mongo.Connect(context.TODO(), opts) if err != nil { return nil, err diff --git a/dbee/clients/mysql.go b/dbee/clients/mysql.go index cbd772d..13c081f 100644 --- a/dbee/clients/mysql.go +++ b/dbee/clients/mysql.go @@ -3,7 +3,7 @@ package clients import ( "database/sql" "fmt" - "regexp" + "net/url" _ "github.com/go-sql-driver/mysql" "github.com/kndndrj/nvim-dbee/dbee/clients/common" @@ -23,19 +23,19 @@ type MysqlClient struct { sql *common.Client } -func NewMysql(url string) (*MysqlClient, error) { +func NewMysql(rawURL string) (*MysqlClient, error) { // add multiple statements support parameter - match, err := regexp.MatchString(`[\?][\w]+=[\w-]+`, url) + u, err := url.Parse(rawURL) if err != nil { - return nil, err + return nil, fmt.Errorf("mysql: invalid url: %w", err) } - if match { - url = url + "&multiStatements=true" - } else { - url = url + "?multiStatements=true" + q := u.Query() + if !q.Has("multiStatements") { + q.Add("multiStatements", "true") } + u.RawQuery = q.Encode() - db, err := sql.Open("mysql", url) + db, err := sql.Open("mysql", rawURL) if err != nil { return nil, fmt.Errorf("unable to connect to mysql database: %v", err) }