From 85e22fd90ba603719c6b3c37858876dd08c1db8f Mon Sep 17 00:00:00 2001 From: Phil Eaton Date: Wed, 22 Dec 2021 19:08:02 -0500 Subject: [PATCH] Fix for remote connections in Go runner (#134) * Fix remote db open * Update arch * Fix for remote db connection * Fix for fmt * Fix for tests --- ARCHITECTURE.md | 8 ++++++ runner/database.go | 61 ++++++++++++++++++++++------------------- runner/database_test.go | 14 +++++++++- runner/ssh.go | 14 ++++++---- 4 files changed, 63 insertions(+), 34 deletions(-) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 2b047c126..80ee0270d 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -33,6 +33,10 @@ something to install all missing dependencies. ### ./desktop/panel +NOTE: This code is being migrated to Go. All panel types except for a +few database vendors have been ported to Go. A number of Node panel +handlers have been deleted since they are no longer used. + This is where eval handlers for each panel type (program, database, etc.) are defined. @@ -45,6 +49,10 @@ this on desktop. ./server/runner.ts is the equivalent on the server. This allows easy resource cleanup and easy "kill" panel eval support. +## ./runner + +This is where the Go port of the original Node.js panel eval code is. + ## ./server This directory contains the server (Express) app and code that proxies diff --git a/runner/database.go b/runner/database.go index 8e502c178..f4ceb8988 100644 --- a/runner/database.go +++ b/runner/database.go @@ -24,16 +24,22 @@ import ( _ "github.com/snowflakedb/gosnowflake" ) -func getDatabaseHostPort(raw, defaultPort string) (string, string, error) { - beforeQuery := strings.Split(raw, "?")[0] +func getDatabaseHostPortExtra(raw, defaultPort string) (string, string, string, error) { + addressAndArgs := strings.SplitN(raw, "?", 2) + extra := "" + beforeQuery := addressAndArgs[0] + if len(addressAndArgs) > 1 { + extra = addressAndArgs[1] + } _, _, err := net.SplitHostPort(beforeQuery) if err != nil && strings.HasSuffix(err.Error(), "missing port in address") { beforeQuery += ":" + defaultPort } else if err != nil { - return "", "", edsef("Could not split host-port: %s", err) + return "", "", "", edsef("Could not split host-port: %s", err) } - return net.SplitHostPort(beforeQuery) + host, port, err := net.SplitHostPort(beforeQuery) + return host, port, extra, err } func debugObject(obj interface{}) { @@ -97,7 +103,7 @@ var defaultPorts = map[DatabaseConnectorInfoType]string{ func getConnectionString(dbInfo DatabaseConnectorInfoDatabase) (string, string, error) { address := dbInfo.Address - split := strings.Split(address, "?") + split := strings.SplitN(address, "?", 2) address = split[0] extraArgs := "" if len(split) > 1 { @@ -272,16 +278,6 @@ func EvalDatabasePanel(project *ProjectState, pageIndex int, panel *PanelInfo, p dbInfo := connector.Database - vendor, connStr, err := getConnectionString(dbInfo) - if err != nil { - return err - } - - db, err := sqlx.Open(vendor, connStr) - if err != nil { - return err - } - mangleInsert := defaultMangleInsert qt := ansiSQLQuote if dbInfo.Type == "postgres" { @@ -306,16 +302,11 @@ func EvalDatabasePanel(project *ProjectState, pageIndex int, panel *PanelInfo, p return err } - // Require queries end with semicolon primarily for Oracle - // that blows up without this. This will still blow up if - // there's no semicolon and there are comments. - // e.g. `SELECT 1 -- flubber` -> `SELECT 1 -- flubber;` - //qWithoutWs := strings.TrimSpace(query) - //if qWithoutWs[len(qWithoutWs)-1] != ';' { - // query += ";" - //} - - server, err := getServer(project, panel.ServerId) + serverId := panel.ServerId + if serverId == "" { + serverId = connector.ServerId + } + server, err := getServer(project, serverId) if err != nil { return err } @@ -338,10 +329,9 @@ func EvalDatabasePanel(project *ProjectState, pageIndex int, panel *PanelInfo, p } dbInfo.Database = tmp.Name() - tmp.Close() } - host, port, err := getDatabaseHostPort(dbInfo.Address, defaultPorts[dbInfo.Type]) + host, port, extra, err := getDatabaseHostPortExtra(dbInfo.Address, defaultPorts[dbInfo.Type]) if err != nil { return err } @@ -360,7 +350,21 @@ func EvalDatabasePanel(project *ProjectState, pageIndex int, panel *PanelInfo, p } defer w.Close() - return withRemoteConnection(server, host, port, func(host, port string) error { + return withRemoteConnection(server, host, port, func(proxyHost, proxyPort string) error { + dbInfo.Address = proxyHost + ":" + proxyPort + if extra != "" { + dbInfo.Address += "?" + extra + } + vendor, connStr, err := getConnectionString(dbInfo) + if err != nil { + return err + } + + db, err := sqlx.Open(vendor, connStr) + if err != nil { + return err + } + wroteFirstRow := false return withJSONArrayOutWriterFile(w, func(w *JSONArrayWriter) error { _, err := importAndRun( @@ -377,6 +381,7 @@ func EvalDatabasePanel(project *ProjectState, pageIndex int, panel *PanelInfo, p if err != nil { return nil, err } + defer rows.Close() for rows.Next() { diff --git a/runner/database_test.go b/runner/database_test.go index 755765d0f..282771964 100644 --- a/runner/database_test.go +++ b/runner/database_test.go @@ -14,6 +14,7 @@ func Test_getConnectionString(t *testing.T) { expErr error expHost string expPort string + expExtra string }{ { DatabaseConnectorInfoDatabase{Type: "postgres", Username: "jim", Password: Encrypt{Encrypted: false, Value: "pw"}, Database: "test", Address: "localhost?sslmode=disable"}, @@ -22,6 +23,7 @@ func Test_getConnectionString(t *testing.T) { nil, "localhost", "5432", + "sslmode=disable", }, { DatabaseConnectorInfoDatabase{Type: "postgres", Database: "test", Address: "big.com:8888?sslmode=disable"}, @@ -30,6 +32,7 @@ func Test_getConnectionString(t *testing.T) { nil, "big.com", "8888", + "sslmode=disable", }, { DatabaseConnectorInfoDatabase{Type: "mysql", Username: "jim", Password: Encrypt{Encrypted: false, Value: "pw"}, Database: "test", Address: "localhost:9090"}, @@ -38,6 +41,7 @@ func Test_getConnectionString(t *testing.T) { nil, "localhost", "9090", + "", }, { DatabaseConnectorInfoDatabase{Type: "sqlite", Database: "test.sql"}, @@ -46,6 +50,7 @@ func Test_getConnectionString(t *testing.T) { nil, "", "", + "", }, { DatabaseConnectorInfoDatabase{Type: "oracle", Username: "jim", Password: Encrypt{Encrypted: false, Value: "pw"}, Database: "test", Address: "localhost"}, @@ -54,6 +59,7 @@ func Test_getConnectionString(t *testing.T) { nil, "localhost", "1521", + "", }, { DatabaseConnectorInfoDatabase{Type: "snowflake", Username: "jim", Password: Encrypt{Encrypted: false, Value: ""}, Database: "test", Address: "myid"}, @@ -62,6 +68,7 @@ func Test_getConnectionString(t *testing.T) { nil, "", "", + "", }, { DatabaseConnectorInfoDatabase{Type: "snowflake", Username: "jim", Password: Encrypt{Encrypted: false, Value: "pw"}, Database: "test", Address: "myid?x=y"}, @@ -70,6 +77,7 @@ func Test_getConnectionString(t *testing.T) { nil, "", "", + "x=y", }, { DatabaseConnectorInfoDatabase{Type: "sqlserver", Username: "jim", Password: Encrypt{Encrypted: false, Value: "pw"}, Database: "test", Address: "localhost"}, @@ -78,6 +86,7 @@ func Test_getConnectionString(t *testing.T) { nil, "localhost", "1433", + "", }, { DatabaseConnectorInfoDatabase{Type: "clickhouse", Username: "jim", Password: Encrypt{Encrypted: false, Value: "pw"}, Database: "test", Address: "localhost"}, @@ -86,6 +95,7 @@ func Test_getConnectionString(t *testing.T) { nil, "localhost", "9000", + "", }, { DatabaseConnectorInfoDatabase{Type: "clickhouse", Password: Encrypt{Encrypted: false, Value: ""}, Database: "test", Address: "localhost:9001"}, @@ -94,6 +104,7 @@ func Test_getConnectionString(t *testing.T) { nil, "localhost", "9001", + "", }, } for _, test := range tests { @@ -106,9 +117,10 @@ func Test_getConnectionString(t *testing.T) { continue } - host, port, err := getDatabaseHostPort(test.conn.Address, defaultPorts[DatabaseConnectorInfoType(test.expVendor)]) + host, port, extra, err := getDatabaseHostPortExtra(test.conn.Address, defaultPorts[DatabaseConnectorInfoType(test.expVendor)]) assert.Nil(t, err) assert.Equal(t, test.expHost, host) assert.Equal(t, test.expPort, port) + assert.Equal(t, test.expExtra, extra) } } diff --git a/runner/ssh.go b/runner/ssh.go index b6e8ea8ea..60dc63d24 100644 --- a/runner/ssh.go +++ b/runner/ssh.go @@ -254,13 +254,17 @@ func withRemoteConnection(si *ServerInfo, host, port string, cb func(host, port localPort := localConn.Addr().(*net.TCPAddr).Port cbErr := cb("localhost", fmt.Sprintf("%d", localPort)) if cbErr != nil { - return err + return cbErr } - err = <-errC - if err == io.EOF { + select { + case err = <-errC: + if err == io.EOF { + return nil + } + + return err + default: return nil } - - return err }