Skip to content

Commit

Permalink
cli: add support for multiple results to SQL statements
Browse files Browse the repository at this point in the history
Use the new functionality in lib/pq to select the next set of results
when multiple statements were executed.

Fixes #4016.
  • Loading branch information
petermattis committed Feb 25, 2016
1 parent b155c6a commit c0c963e
Show file tree
Hide file tree
Showing 14 changed files with 158 additions and 119 deletions.
1 change: 1 addition & 0 deletions GLOCKFILE
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ github.com/cockroachdb/c-lz4 c40aaae2fc50293eb8750b34632bc3efe813e23f
github.com/cockroachdb/c-protobuf 4feb192131ea08dfbd7253a00868ad69cbb61b81
github.com/cockroachdb/c-rocksdb c0124c907c74b579d9d3d48eb96471bef270bc25
github.com/cockroachdb/c-snappy 5c6d0932e0adaffce4bfca7bdf2ac37f79952ccf
github.com/cockroachdb/pq 77893094b774b29f293681e6ac0a9322fbf3ce25
github.com/cockroachdb/stress aa7690c22fd0abd6168ed0e6c361e4f4c5f7ab25
github.com/codahale/hdrhistogram e88be87d51429689cef99043a54150d733265cd7
github.com/coreos/etcd 410d32a9b14f6052a834a966a02950cde518d7ce
Expand Down
8 changes: 8 additions & 0 deletions cli/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,7 @@ func Example_sql() {
c.RunWithArgs([]string{"sql", "-e", "select * from t.f"})
c.RunWithArgs([]string{"sql", "-e", "show databases"})
c.RunWithArgs([]string{"sql", "-e", "explain select 3"})
c.RunWithArgs([]string{"sql", "-e", "select 1; select 2"})

// Output:
// sql -e create database t; create table t.f (x int, y int); insert into t.f values (42, 69)
Expand Down Expand Up @@ -568,6 +569,13 @@ func Example_sql() {
// 1 row
// Level Type Description
// 0 empty -
// sql -e select 1; select 2
// 1 row
// 1
// 1
// 1 row
// 2
// 2
}

func Example_sql_escape() {
Expand Down
5 changes: 2 additions & 3 deletions cli/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,8 @@ subsequent positional argument on the command line may contain
one or more SQL statements, separated by semicolons. If an
error occurs in any statement, the command exits with a
non-zero status code and further statements are not
executed. Only the results of the first SQL statement in each
positional argument are printed on the standard output.`),

executed. The results of each SQL statement are printed on
the standard output.`),
"join": wrapText(`
A comma-separated list of addresses to use when a new node is joining
an existing cluster. For the first node in a cluster, --join should
Expand Down
52 changes: 30 additions & 22 deletions cli/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import (
"github.com/cockroachdb/cockroach/util/log"
"github.com/mattn/go-isatty"
"github.com/spf13/cobra"

"github.com/cockroachdb/pq"
)

const (
Expand Down Expand Up @@ -210,7 +212,7 @@ func runInteractive(conn *sqlConn) (exitErr error) {
readline.SetHistoryPath("")
}

if exitErr = runPrettyQuery(conn, os.Stdout, fullStmt); exitErr != nil {
if exitErr = runPrettyQuery(conn, os.Stdout, makeQuery(fullStmt)); exitErr != nil {
fmt.Fprintln(osStderr, exitErr)
}

Expand All @@ -225,30 +227,36 @@ func runInteractive(conn *sqlConn) (exitErr error) {
// on error.
func runStatements(conn *sqlConn, stmts []string) error {
for _, stmt := range stmts {
fullStmt := stmt + "\n"
cols, allRows, err := runQuery(conn, fullStmt)
if err != nil {
fmt.Fprintln(osStderr, err)
return err
}

if len(cols) == 0 {
// No result selected, inform the user.
fmt.Fprintln(os.Stdout, "OK")
} else {
// Some results selected, inform the user about how much data to expect.
noun := "rows"
if len(allRows) == 1 {
noun = "row"
q := makeQuery(stmt)
for {
cols, allRows, err := runQuery(conn, q)
if err != nil {
if err == pq.ErrNoMoreResults {
break
}
fmt.Fprintln(osStderr, err)
os.Exit(1)
}

fmt.Fprintf(os.Stdout, "%d %s\n", len(allRows), noun)

// Then print the results themselves.
fmt.Fprintln(os.Stdout, strings.Join(cols, "\t"))
for _, row := range allRows {
fmt.Fprintln(os.Stdout, strings.Join(row, "\t"))
if len(cols) == 0 {
// No result selected, inform the user.
fmt.Fprintln(os.Stdout, "OK")
} else {
// Some results selected, inform the user about how much data to expect.
noun := "rows"
if len(allRows) == 1 {
noun = "row"
}

fmt.Fprintf(os.Stdout, "%d %s\n", len(allRows), noun)

// Then print the results themselves.
fmt.Fprintln(os.Stdout, strings.Join(cols, "\t"))
for _, row := range allRows {
fmt.Fprintln(os.Stdout, strings.Join(row, "\t"))
}
}
q = nextResult
}
}
return nil
Expand Down
89 changes: 59 additions & 30 deletions cli/sql_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ import (
"fmt"
"io"

"github.com/lib/pq"

"github.com/olekukonko/tablewriter"

"github.com/cockroachdb/cockroach/util/log"
"github.com/cockroachdb/pq"
)

type sqlConnI interface {
driver.Conn
driver.Queryer
Next() (driver.Rows, error)
}

type sqlConn struct {
Expand Down Expand Up @@ -64,6 +64,20 @@ func (c *sqlConn) Query(query string, args []driver.Value) (*sqlRows, error) {
return &sqlRows{Rows: rows, conn: c}, nil
}

func (c *sqlConn) Next() (*sqlRows, error) {
if c.conn == nil {
return nil, driver.ErrBadConn
}
rows, err := c.conn.Next()
if err == driver.ErrBadConn {
c.Close()
}
if err != nil {
return nil, err
}
return &sqlRows{Rows: rows, conn: c}, nil
}

func (c *sqlConn) Close() {
if c.conn != nil {
err := c.conn.Close()
Expand Down Expand Up @@ -113,51 +127,66 @@ func makeSQLClient() *sqlConn {
// and outputs the string to be displayed.
type fmtMap map[string]func(driver.Value) string

type queryFunc func(conn *sqlConn) (*sqlRows, error)

func nextResult(conn *sqlConn) (*sqlRows, error) {
return conn.Next()
}

func makeQuery(query string, parameters ...driver.Value) queryFunc {
return func(conn *sqlConn) (*sqlRows, error) {
// driver.Value is an alias for interface{}, but must adhere to a restricted
// set of types when being passed to driver.Queryer.Query (see
// driver.IsValue). We use driver.DefaultParameterConverter to perform the
// necessary conversion. This is usually taken care of by the sql package,
// but we have to do so manually because we're talking directly to the
// driver.
for i := range parameters {
var err error
parameters[i], err = driver.DefaultParameterConverter.ConvertValue(parameters[i])
if err != nil {
return nil, err
}
}
return conn.Query(query, parameters)
}
}

// runQuery takes a 'query' with optional 'parameters'.
// It runs the sql query and returns a list of columns names and a list of rows.
func runQuery(db *sqlConn, query string, parameters ...driver.Value) (
[]string, [][]string, error) {
return runQueryWithFormat(db, nil, query, parameters...)
func runQuery(conn *sqlConn, fn queryFunc) ([]string, [][]string, error) {
return runQueryWithFormat(conn, nil, fn)
}

// runQuery takes a 'query' with optional 'parameters'.
// runQueryWithFormat takes a 'query' with optional 'parameters'.
// It runs the sql query and returns a list of columns names and a list of rows.
// If 'format' is not nil, the values with column name
// found in the map are run through the corresponding callback.
func runQueryWithFormat(db *sqlConn, format fmtMap, query string, parameters ...driver.Value) (
func runQueryWithFormat(conn *sqlConn, format fmtMap, fn queryFunc) (
[]string, [][]string, error) {
// driver.Value is an alias for interface{}, but must adhere to a restricted
// set of types when being passed to driver.Queryer.Query (see
// driver.IsValue). We use driver.DefaultParameterConverter to perform the
// necessary conversion. This is usually taken care of by the sql package,
// but we have to do so manually because we're talking directly to the
// driver.
for i := range parameters {
var err error
parameters[i], err = driver.DefaultParameterConverter.ConvertValue(parameters[i])
if err != nil {
return nil, nil, err
}
}

rows, err := db.Query(query, parameters)
rows, err := fn(conn)
if err != nil {
return nil, nil, fmt.Errorf("query error: %s", err)
return nil, nil, err
}

defer func() { _ = rows.Close() }()
return sqlRowsToStrings(rows, format)
}

// runPrettyQueryWithFormat takes a 'query' with optional 'parameters'.
// runPrettyQuery takes a 'query' with optional 'parameters'.
// It runs the sql query and writes pretty output to 'w'.
func runPrettyQuery(db *sqlConn, w io.Writer, query string, parameters ...driver.Value) error {
cols, allRows, err := runQuery(db, query, parameters...)
if err != nil {
return err
func runPrettyQuery(conn *sqlConn, w io.Writer, fn queryFunc) error {
for {
cols, allRows, err := runQuery(conn, fn)
if err != nil {
if err == pq.ErrNoMoreResults {
return nil
}
return err
}
printQueryOutput(w, cols, allRows)
fn = nextResult
}
printQueryOutput(w, cols, allRows)
return nil
}

// sqlRowsToStrings turns 'rows' into a list of rows, each of which
Expand Down
74 changes: 35 additions & 39 deletions cli/sql_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestRunQuery(t *testing.T) {
var b bytes.Buffer

// Non-query statement.
if err := runPrettyQuery(conn, &b, `SET DATABASE=system`); err != nil {
if err := runPrettyQuery(conn, &b, makeQuery(`SET DATABASE=system`)); err != nil {
t.Fatal(err)
}

Expand All @@ -57,7 +57,7 @@ OK
b.Reset()

// Use system database for sample query/output as they are fairly fixed.
cols, rows, err := runQuery(conn, `SHOW COLUMNS FROM system.namespace`)
cols, rows, err := runQuery(conn, makeQuery(`SHOW COLUMNS FROM system.namespace`))
if err != nil {
t.Fatal(err)
}
Expand All @@ -76,7 +76,8 @@ OK
t.Fatalf("expected:\n%v\ngot:\n%v", expectedRows, rows)
}

if err := runPrettyQuery(conn, &b, `SHOW COLUMNS FROM system.namespace`); err != nil {
if err := runPrettyQuery(conn, &b,
makeQuery(`SHOW COLUMNS FROM system.namespace`)); err != nil {
t.Fatal(err)
}

Expand All @@ -96,7 +97,8 @@ OK
b.Reset()

// Test placeholders.
if err := runPrettyQuery(conn, &b, `SELECT * FROM system.namespace WHERE name=$1`, "descriptor"); err != nil {
if err := runPrettyQuery(conn, &b,
makeQuery(`SELECT * FROM system.namespace WHERE name=$1`, "descriptor")); err != nil {
t.Fatal(err)
}

Expand All @@ -118,7 +120,7 @@ OK
}

_, rows, err = runQueryWithFormat(conn, fmtMap{"name": newFormat},
`SELECT * FROM system.namespace WHERE name=$1`, "descriptor")
makeQuery(`SELECT * FROM system.namespace WHERE name=$1`, "descriptor"))
if err != nil {
t.Fatal(err)
}
Expand All @@ -129,38 +131,32 @@ OK
}
b.Reset()

// TODO(pmattis): This test case fails now as lib/pq doesn't handle multiple
// results correctly. We were previously incorrectly ignoring the error from
// sql.Rows.Err() which is what allowed the test to pass.

/**
// Test multiple results.
if err := runPrettyQuery(conn, &b, `SELECT 1; SELECT 2, 3; SELECT 'hello'`); err != nil {
t.Fatal(err)
}
expected = `
+---+
| 1 |
+---+
| 1 |
+---+
`
// TODO(pmattis): When #4016 is fixed, we should see:
// +---+---+
// | 2 | 3 |
// +---+---+
// | 2 | 3 |
// +---+---+
// +---------+
// | 'hello' |
// +---------+
// | "hello" |
// +---------+
if a, e := b.String(), expected[1:]; a != e {
t.Fatalf("expected output:\n%s\ngot:\n%s", e, a)
}
b.Reset()
**/
// Test multiple results.
if err := runPrettyQuery(conn, &b,
makeQuery(`SELECT 1; SELECT 2, 3; SELECT 'hello'`)); err != nil {
t.Fatal(err)
}

expected = `
+---+
| 1 |
+---+
| 1 |
+---+
+---+---+
| 2 | 3 |
+---+---+
| 2 | 3 |
+---+---+
+---------+
| 'hello' |
+---------+
| hello |
+---------+
`

if a, e := b.String(), expected[1:]; a != e {
t.Fatalf("expected output:\n%s\ngot:\n%s", e, a)
}
b.Reset()
}
Loading

0 comments on commit c0c963e

Please sign in to comment.