Skip to content

Add --database flag to sqlcmd query #288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Mar 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/modern/root/install/mssql-base.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ func (c *MssqlBase) createContainer(imageName string, contextName string) {
Password: secret.Encode(saPassword, c.encryptPassword)},
Name: "sa"}

c.sql.Connect(endpoint, saUser, sql.ConnectOptions{Interactive: false})
c.sql.Connect(endpoint, saUser, sql.ConnectOptions{Database: "master", Interactive: false})

c.createNonSaUser(userName, password)

Expand Down
25 changes: 21 additions & 4 deletions cmd/modern/root/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@
package root

import (
"fmt"
"github.com/microsoft/go-sqlcmd/internal/cmdparser"
"github.com/microsoft/go-sqlcmd/internal/config"
"github.com/microsoft/go-sqlcmd/internal/pal"
"github.com/microsoft/go-sqlcmd/internal/sql"
)

// Query defines the `sqlcmd query` command
type Query struct {
cmdparser.Cmd

text string
text string
database string
}

func (c *Query) DefineCommand(...cmdparser.CommandOptions) {
Expand All @@ -25,7 +28,15 @@ func (c *Query) DefineCommand(...cmdparser.CommandOptions) {
`sqlcmd query "SELECT @@SERVERNAME"`,
`sqlcmd query --text "SELECT @@SERVERNAME"`,
`sqlcmd query --query "SELECT @@SERVERNAME"`,
}}},
}},
{Description: "Run a query using [master] database", Steps: []string{
`sqlcmd query "SELECT DB_NAME()" --database master`,
}},
{Description: "Set new default database", Steps: []string{
fmt.Sprintf(`sqlcmd query "ALTER LOGIN [%s] WITH DEFAULT_DATABASE = [tempdb]" --database master`,
pal.UserName()),
}},
},
Run: c.run,
FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagOptions{
Flag: "text",
Expand All @@ -47,6 +58,12 @@ func (c *Query) DefineCommand(...cmdparser.CommandOptions) {
Name: "query",
Shorthand: "q",
Usage: "Command text to run"})

c.AddFlag(cmdparser.FlagOptions{
String: &c.database,
Name: "database",
Shorthand: "d",
Usage: "Database to use"})
}

// run executes the Query command.
Expand All @@ -58,9 +75,9 @@ func (c *Query) run() {

s := sql.New(sql.SqlOptions{})
if c.text == "" {
s.Connect(endpoint, user, sql.ConnectOptions{Interactive: true})
s.Connect(endpoint, user, sql.ConnectOptions{Database: c.database, Interactive: true})
} else {
s.Connect(endpoint, user, sql.ConnectOptions{Interactive: false})
s.Connect(endpoint, user, sql.ConnectOptions{Database: c.database, Interactive: false})
}

s.Query(c.text)
Expand Down
25 changes: 19 additions & 6 deletions cmd/modern/root/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"fmt"
"github.com/microsoft/go-sqlcmd/cmd/modern/root/config"
"github.com/microsoft/go-sqlcmd/internal/cmdparser"
"github.com/stretchr/testify/assert"
"os"
"runtime"
"testing"
Expand All @@ -18,8 +17,27 @@ func TestQuery(t *testing.T) {
if runtime.GOOS != "windows" {
t.Skip("stuartpa: This is failing in the pipeline (Login failed for user 'sa'.)")
}

cmdparser.TestSetup(t)

setupContext(t)
cmdparser.TestCmd[*Query]("PRINT")
}

func TestQueryWithNonDefaultDatabase(t *testing.T) {
if runtime.GOOS != "windows" {
t.Skip("stuartpa: This is failing in the pipeline (Login failed for user 'sa'.)")
}

cmdparser.TestSetup(t)

setupContext(t)
cmdparser.TestCmd[*Query](`--text "PRINT DB_NAME()" --database master`)

// TODO: Add test validation that DB name was actually master!
}

func setupContext(t *testing.T) {
// if SQLCMDSERVER != "" add an endpoint using the --address
if os.Getenv("SQLCMDSERVER") == "" {
cmdparser.TestCmd[*config.AddEndpoint]()
Expand All @@ -33,10 +51,6 @@ func TestQuery(t *testing.T) {
if os.Getenv("SQLCMDPASSWORD") != "" &&
os.Getenv("SQLCMDUSER") != "" {

// sqlcmd uses the SQLCMD_PASSWORD env var, but the tests use the
// SQLCMDPASSWORD env var
err := os.Setenv("SQLCMD_PASSWORD", os.Getenv("SQLCMDPASSWORD"))
assert.Nil(t, err)
cmdparser.TestCmd[*config.AddUser](
fmt.Sprintf("--name user1 --username %s",
os.Getenv("SQLCMDUSER")))
Expand All @@ -45,5 +59,4 @@ func TestQuery(t *testing.T) {
cmdparser.TestCmd[*config.AddContext]("--endpoint endpoint")
}
cmdparser.TestCmd[*config.View]() // displaying the config (info in-case test fails)
cmdparser.TestCmd[*Query]("PRINT")
}
2 changes: 2 additions & 0 deletions internal/sql/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,7 @@ type Sql interface {
}

type ConnectOptions struct {
Database string

Interactive bool
}
4 changes: 4 additions & 0 deletions internal/sql/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ func (m *mssql) Connect(
ApplicationName: "sqlcmd",
}

if options.Database != "" {
connect.Database = options.Database
}

if user == nil {
connect.UseTrustedConnection = true
} else {
Expand Down