diff --git a/cmd/modern/root/install/mssql-base.go b/cmd/modern/root/install/mssql-base.go index 9a17afd5..e5c32ad0 100644 --- a/cmd/modern/root/install/mssql-base.go +++ b/cmd/modern/root/install/mssql-base.go @@ -326,7 +326,7 @@ func (c *MssqlBase) createContainer(imageName string, contextName string) { Password: secret.Encode(saPassword, c.encryptPassword)}, Name: "sa"} - c.sql.Connect(endpoint, saUser, sql.ConnectOptions{Interactive: false}) + c.sql.Connect(endpoint, saUser, sql.ConnectOptions{Database: "master", Interactive: false}) c.createNonSaUser(userName, password) diff --git a/cmd/modern/root/query.go b/cmd/modern/root/query.go index 684f4dc3..86fb62a9 100644 --- a/cmd/modern/root/query.go +++ b/cmd/modern/root/query.go @@ -4,8 +4,10 @@ package root import ( + "fmt" "github.com/microsoft/go-sqlcmd/internal/cmdparser" "github.com/microsoft/go-sqlcmd/internal/config" + "github.com/microsoft/go-sqlcmd/internal/pal" "github.com/microsoft/go-sqlcmd/internal/sql" ) @@ -13,7 +15,8 @@ import ( type Query struct { cmdparser.Cmd - text string + text string + database string } func (c *Query) DefineCommand(...cmdparser.CommandOptions) { @@ -25,7 +28,15 @@ func (c *Query) DefineCommand(...cmdparser.CommandOptions) { `sqlcmd query "SELECT @@SERVERNAME"`, `sqlcmd query --text "SELECT @@SERVERNAME"`, `sqlcmd query --query "SELECT @@SERVERNAME"`, - }}}, + }}, + {Description: "Run a query using [master] database", Steps: []string{ + `sqlcmd query "SELECT DB_NAME()" --database master`, + }}, + {Description: "Set new default database", Steps: []string{ + fmt.Sprintf(`sqlcmd query "ALTER LOGIN [%s] WITH DEFAULT_DATABASE = [tempdb]" --database master`, + pal.UserName()), + }}, + }, Run: c.run, FirstArgAlternativeForFlag: &cmdparser.AlternativeForFlagOptions{ Flag: "text", @@ -47,6 +58,12 @@ func (c *Query) DefineCommand(...cmdparser.CommandOptions) { Name: "query", Shorthand: "q", Usage: "Command text to run"}) + + c.AddFlag(cmdparser.FlagOptions{ + String: &c.database, + Name: "database", + Shorthand: "d", + Usage: "Database to use"}) } // run executes the Query command. @@ -58,9 +75,9 @@ func (c *Query) run() { s := sql.New(sql.SqlOptions{}) if c.text == "" { - s.Connect(endpoint, user, sql.ConnectOptions{Interactive: true}) + s.Connect(endpoint, user, sql.ConnectOptions{Database: c.database, Interactive: true}) } else { - s.Connect(endpoint, user, sql.ConnectOptions{Interactive: false}) + s.Connect(endpoint, user, sql.ConnectOptions{Database: c.database, Interactive: false}) } s.Query(c.text) diff --git a/cmd/modern/root/query_test.go b/cmd/modern/root/query_test.go index 7d1d2289..496cead4 100644 --- a/cmd/modern/root/query_test.go +++ b/cmd/modern/root/query_test.go @@ -7,7 +7,6 @@ import ( "fmt" "github.com/microsoft/go-sqlcmd/cmd/modern/root/config" "github.com/microsoft/go-sqlcmd/internal/cmdparser" - "github.com/stretchr/testify/assert" "os" "runtime" "testing" @@ -18,8 +17,27 @@ func TestQuery(t *testing.T) { if runtime.GOOS != "windows" { t.Skip("stuartpa: This is failing in the pipeline (Login failed for user 'sa'.)") } + cmdparser.TestSetup(t) + setupContext(t) + cmdparser.TestCmd[*Query]("PRINT") +} + +func TestQueryWithNonDefaultDatabase(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("stuartpa: This is failing in the pipeline (Login failed for user 'sa'.)") + } + + cmdparser.TestSetup(t) + + setupContext(t) + cmdparser.TestCmd[*Query](`--text "PRINT DB_NAME()" --database master`) + + // TODO: Add test validation that DB name was actually master! +} + +func setupContext(t *testing.T) { // if SQLCMDSERVER != "" add an endpoint using the --address if os.Getenv("SQLCMDSERVER") == "" { cmdparser.TestCmd[*config.AddEndpoint]() @@ -33,10 +51,6 @@ func TestQuery(t *testing.T) { if os.Getenv("SQLCMDPASSWORD") != "" && os.Getenv("SQLCMDUSER") != "" { - // sqlcmd uses the SQLCMD_PASSWORD env var, but the tests use the - // SQLCMDPASSWORD env var - err := os.Setenv("SQLCMD_PASSWORD", os.Getenv("SQLCMDPASSWORD")) - assert.Nil(t, err) cmdparser.TestCmd[*config.AddUser]( fmt.Sprintf("--name user1 --username %s", os.Getenv("SQLCMDUSER"))) @@ -45,5 +59,4 @@ func TestQuery(t *testing.T) { cmdparser.TestCmd[*config.AddContext]("--endpoint endpoint") } cmdparser.TestCmd[*config.View]() // displaying the config (info in-case test fails) - cmdparser.TestCmd[*Query]("PRINT") } diff --git a/internal/sql/interface.go b/internal/sql/interface.go index c9b04817..9203bda5 100644 --- a/internal/sql/interface.go +++ b/internal/sql/interface.go @@ -14,5 +14,7 @@ type Sql interface { } type ConnectOptions struct { + Database string + Interactive bool } diff --git a/internal/sql/mssql.go b/internal/sql/mssql.go index b7759f6f..507fb78d 100644 --- a/internal/sql/mssql.go +++ b/internal/sql/mssql.go @@ -40,6 +40,10 @@ func (m *mssql) Connect( ApplicationName: "sqlcmd", } + if options.Database != "" { + connect.Database = options.Database + } + if user == nil { connect.UseTrustedConnection = true } else {