Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [



{
Expand All @@ -24,16 +25,16 @@
"type" : "go",
"request": "launch",
"mode" : "auto",
"program": "${workspaceFolder}/cmd/sqlcmd",
"args" : ["-Q", "EXIT(select 100 as Count)"],
"program": "${workspaceFolder}/cmd/modern",
"args" : ["-Q", "EXIT(select net_transport from sys.dm_exec_connections)"],
},
{
"name" : "Run file query",
"type" : "go",
"request": "launch",
"mode" : "auto",
"program": "${workspaceFolder}/cmd/sqlcmd",
"args" : ["-i", "testdata\\select100.sql"],
"program": "${workspaceFolder}/cmd/modern",
"args" : ["-S", "np:.", "-i", "${workspaceFolder}/cmd/sqlcmd/testdata/select100.sql"],
},
]
}
17 changes: 17 additions & 0 deletions NOTICE.md
Original file line number Diff line number Diff line change
Expand Up @@ -5003,6 +5003,23 @@ third-party archives.

```

## gopkg.in/natefinch/npipe.v2

* Name: gopkg.in/natefinch/npipe.v2
* Version: v2.0.0-20160621034901-c1b8fa8bdcce
* License: [MIT](https://github.com/natefinch/npipe/blob/c1b8fa8bdcce/LICENSE.txt)

```
The MIT License (MIT)
Copyright (c) 2013 npipe authors

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
```

## gopkg.in/yaml.v2

* Name: gopkg.in/yaml.v2
Expand Down
29 changes: 23 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,28 @@ We will be implementing command line switches and behaviors over time. Several s
- The new `--driver-logging-level` command line parameter allows you to see traces from the `go-mssqldb` client driver. Use `64` to see all traces.
- Sqlcmd can now print results using a vertical format. Use the new `-F vertical` command line option to set it. It's also controlled by the `SQLCMDFORMAT` scripting variable.

```

1> select session_id, client_interface_name, program_name from sys.dm_exec_sessions where session_id=@@spid
2> go
session_id 58
client_interface_name go-mssqldb
program_name sqlcmd

```
- Sqlcmd now supports shared memory and named pipe transport. Use the appropriate protocol prefix on the server name to force a protocol
* `lpc` for shared memory, only for a localhost. `sqlcmd -S lpc:.`
* `np` for named pipes. Or use the UNC named pipe path as the server name: `sqlcmd -S \\myserver\pipe\sql\query`
* `tcp` for tcp `sqlcmd -S tcp:myserver,1234`
If no protocol is specified, sqlcmd will attempt to dial in this order: lpc->np->tcp. If dialing a remote host, `lpc` will be skipped.

```
1> select net_transport from sys.dm_exec_connections where session_id=@@spid
2> go
net_transport Named pipe

```

### Azure Active Directory Authentication

This version of sqlcmd supports a broader range of AAD authentication models, based on the [azidentity package](https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity). The implementation relies on an AAD Connector in the [driver](https://github.com/microsoft/go-mssqldb).
Expand Down Expand Up @@ -105,18 +127,13 @@ pkg/sqlcmd is consumable by other hosts. Go docs for the package are forthcoming

## Building

To add version data to your build using `go-winres`, add `GOBIN` to your `PATH` then use `go generate`
The version on the binary will match the version tag of the branch.

```sh

go install github.com/tc-hib/go-winres@latest
cd cmd/modern
go generate
build/build

```

Scripts to build the binaries and package them for release will be added in a build folder off the root. We will also add Azure Devops pipeline yml files there to initiate builds and releases. Until then just use `go build ./cmd/sqlcmd` to create a sqlcmd binary.

## Testing

Expand Down
3 changes: 2 additions & 1 deletion cmd/sqlcmd/sqlcmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type SQLCmdArguments struct {
InitialQuery string `short:"q" xor:"input1" help:"Executes a query when sqlcmd starts, but does not exit sqlcmd when the query has finished running. Multiple-semicolon-delimited queries can be executed."`
// Query to run then exit
Query string `short:"Q" xor:"input2" help:"Executes a query when sqlcmd starts and then immediately exits sqlcmd. Multiple-semicolon-delimited queries can be executed."`
Server string `short:"S" help:"[tcp:]server[\\instance_name][,port]Specifies the instance of SQL Server to which to connect. It sets the sqlcmd scripting variable SQLCMDSERVER."`
Server string `short:"S" help:"[[tcp:]|[lpc:]|[np:]]server[\\instance_name][,port]Specifies the instance of SQL Server to which to connect. It sets the sqlcmd scripting variable SQLCMDSERVER."`
// Disable syscommands with a warning
DisableCmdAndWarn bool `short:"X" xor:"syscmd" help:"Disables commands that might compromise system security. Sqlcmd issues a warning and continues."`
// AuthenticationMethod is new for go-sqlcmd
Expand Down Expand Up @@ -291,6 +291,7 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
// connect using no overrides
err = s.ConnectDb(nil, line == nil)
if err != nil {
s.WriteError(s.GetError(), err)
return 1, err
}

Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ require (
golang.org/x/tools v0.1.12 // indirect
google.golang.org/protobuf v1.28.1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gotest.tools/v3 v3.4.0 // indirect
)
16 changes: 15 additions & 1 deletion pkg/sqlcmd/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package sqlcmd
import (
"fmt"
"net/url"
"strings"

"github.com/microsoft/go-mssqldb/azuread"
)
Expand Down Expand Up @@ -81,7 +82,7 @@ func (connect ConnectSettings) RequiresPassword() bool {

// ConnectionString returns the go-mssql connection string to use for queries
func (connect ConnectSettings) ConnectionString() (connectionString string, err error) {
serverName, instance, port, err := splitServer(connect.ServerName)
serverName, instance, port, protocol, err := splitServer(connect.ServerName)
if serverName == "" {
serverName = "."
}
Expand All @@ -100,6 +101,16 @@ func (connect ConnectSettings) ConnectionString() (connectionString string, err
if (connect.authenticationMethod() == azuread.ActiveDirectoryMSI || connect.authenticationMethod() == azuread.ActiveDirectoryManagedIdentity) && connect.UserName != "" {
connectionURL.User = url.UserPassword(connect.UserName, connect.Password)
}

if strings.HasPrefix(serverName, `\\`) {
// passing a pipe name of the format \\server\pipe\<pipename>
pipeParts := strings.SplitN(string(serverName[2:]), `\`, 3)
if len(pipeParts) != 3 {
return "", &InvalidServerName
}
serverName = pipeParts[0]
query.Add("pipe", pipeParts[2])
}
if port > 0 {
connectionURL.Host = fmt.Sprintf("%s:%d", serverName, port)
} else {
Expand Down Expand Up @@ -130,6 +141,9 @@ func (connect ConnectSettings) ConnectionString() (connectionString string, err
if connect.LogLevel > 0 {
query.Add("log", fmt.Sprint(connect.LogLevel))
}
if protocol != "" {
query.Add("protocol", protocol)
}
if connect.ApplicationName != "" {
query.Add(`app name`, connect.ApplicationName)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sqlcmd/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (e *ArgumentError) IsSqlcmdErr() bool {
// InvalidServerName indicates the SQLCMDSERVER variable has an incorrect format
var InvalidServerName = ArgumentError{
Parameter: "server",
Rule: "server must be of the form [tcp]:server[[/instance]|[,port]]",
Rule: "server must be of the form [[np]|[lpc][tcp]]:server[[/instance]|[,port]]",
}

// VariableError is an error about scripting variables
Expand Down
15 changes: 15 additions & 0 deletions pkg/sqlcmd/sqlcmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,15 @@ import (
"github.com/golang-sql/sqlexp"
mssql "github.com/microsoft/go-mssqldb"
"github.com/microsoft/go-mssqldb/msdsn"
_ "github.com/microsoft/go-mssqldb/namedpipe"
_ "github.com/microsoft/go-mssqldb/sharedmemory"
"golang.org/x/text/encoding/unicode"
"golang.org/x/text/transform"
)

// Note: The order of includes above matters for namedpipe and sharedmemory.
// init() swaps shared memory protocol with tcp so it gets priority when dialing.

var (
// ErrExitRequested tells the hosting application to exit immediately
ErrExitRequested = errors.New("exit")
Expand Down Expand Up @@ -534,3 +539,13 @@ func (s Sqlcmd) Log(_ context.Context, _ msdsn.Log, msg string) {
_, _ = s.GetOutput().Write([]byte("DRIVER:" + msg))
_, _ = s.GetOutput().Write([]byte(SqlcmdEol))
}

func init() {
if len(msdsn.ProtocolParsers) == 3 {
// reorder the protocol parsers to lpc->np->tcp
// ODBC follows this same order.
var tcp = msdsn.ProtocolParsers[0]
msdsn.ProtocolParsers[0] = msdsn.ProtocolParsers[2]
msdsn.ProtocolParsers[2] = tcp
}
}
33 changes: 26 additions & 7 deletions pkg/sqlcmd/sqlcmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ import (
"io"
"os"
"os/user"
"runtime"
"strings"
"testing"

"github.com/microsoft/go-mssqldb/azuread"
"github.com/microsoft/go-mssqldb/msdsn"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -46,16 +48,20 @@ func TestConnectionStringFromSqlCmd(t *testing.T) {
},
{
&ConnectSettings{TrustServerCertificate: true, UseTrustedConnection: true, Password: pwd, ServerName: `tcp:someserver,1045`, UserName: "someuser"},
"sqlserver://someserver:1045?trustservercertificate=true",
"sqlserver://someserver:1045?protocol=tcp&trustservercertificate=true",
},
{
&ConnectSettings{ServerName: `tcp:someserver,1045`},
"sqlserver://someserver:1045",
"sqlserver://someserver:1045?protocol=tcp",
},
{
&ConnectSettings{ServerName: "someserver", AuthenticationMethod: azuread.ActiveDirectoryServicePrincipal, UserName: "myapp@mytenant", Password: pwd},
fmt.Sprintf("sqlserver://myapp%%40mytenant:%s@someserver", pwd),
},
{
&ConnectSettings{ServerName: `\\someserver\pipe\sql\query`},
"sqlserver://someserver?pipe=sql%5Cquery&protocol=np",
},
}

for i, test := range commands {
Expand Down Expand Up @@ -356,16 +362,20 @@ func TestPromptForPasswordNegative(t *testing.T) {
}
v := InitializeVariables(true)
s := New(console, "", v)
c := newConnect(t)
s.Connect.UserName = "someuser"
s.Connect.ServerName = c.ServerName
err := s.ConnectDb(nil, false)
assert.True(t, prompted, "Password prompt not shown for SQL auth")
assert.Error(t, err, "ConnectDb")
prompted = false
s.Connect.AuthenticationMethod = azuread.ActiveDirectoryPassword
err = s.ConnectDb(nil, false)
assert.True(t, prompted, "Password prompt not shown for AD Password auth")
assert.Error(t, err, "ConnectDb")
prompted = false
if canTestAzureAuth() {
s.Connect.AuthenticationMethod = azuread.ActiveDirectoryPassword
err = s.ConnectDb(nil, false)
assert.True(t, prompted, "Password prompt not shown for AD Password auth")
assert.Error(t, err, "ConnectDb")
prompted = false
}
}

func TestPromptForPasswordPositive(t *testing.T) {
Expand Down Expand Up @@ -619,3 +629,12 @@ func newConnect(t testing.TB) *ConnectSettings {
}
return &connect
}

func TestSqlcmdPrefersSharedMemoryProtocol(t *testing.T) {
if runtime.GOOS != "windows" {
t.Skip()
}
assert.EqualValuesf(t, "lpc", msdsn.ProtocolParsers[0].Protocol(), "lpc should be first protocol")
assert.EqualValuesf(t, "np", msdsn.ProtocolParsers[1].Protocol(), "np should be second protocol")
assert.EqualValuesf(t, "tcp", msdsn.ProtocolParsers[2].Protocol(), "tcp should be third protocol")
}
59 changes: 38 additions & 21 deletions pkg/sqlcmd/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,40 +6,57 @@ package sqlcmd
import (
"strconv"
"strings"

"github.com/microsoft/go-mssqldb/msdsn"
)

// splitServer extracts connection parameters from a server name input
func splitServer(serverName string) (string, string, uint64, error) {
instance := ""
port := uint64(0)
if strings.HasPrefix(serverName, "tcp:") {
if len(serverName) == 4 {
return "", "", 0, &InvalidServerName
func splitServer(serverName string) (string, instance string, port uint64, protocol string, err error) {
instance = ""
port = uint64(0)
protocol = ""
err = nil
// We don't just look for : due to possible IPv6 address
for _, p := range msdsn.ProtocolParsers {
prefix := p.Protocol() + ":"
if strings.HasPrefix(serverName, prefix) {
if len(serverName) == len(prefix) {
serverName = "."
} else {
serverName = serverName[len(prefix):]
}
protocol = p.Protocol()
}
serverName = serverName[4:]
}
serverNameParts := strings.Split(serverName, ",")
if len(serverNameParts) > 2 {
return "", "", 0, &InvalidServerName
}
if len(serverNameParts) == 2 {
var err error
port, err = strconv.ParseUint(serverNameParts[1], 10, 16)
if err != nil {
return "", "", 0, &InvalidServerName
if strings.HasPrefix(serverName, `\\`) {
if protocol != "np" && protocol != "" || len(serverName) == 2 {
return "", "", 0, "", &InvalidServerName
}
serverName = serverNameParts[0]
protocol = "np"
} else {
serverNameParts = strings.Split(serverName, "\\")
serverNameParts := strings.Split(serverName, ",")
if len(serverNameParts) > 2 {
return "", "", 0, &InvalidServerName
return "", "", 0, "", &InvalidServerName
}
if len(serverNameParts) == 2 {
instance = serverNameParts[1]
var err error
port, err = strconv.ParseUint(serverNameParts[1], 10, 16)
if err != nil {
return "", "", 0, "", &InvalidServerName
}
serverName = serverNameParts[0]
} else {
serverNameParts = strings.Split(serverName, "\\")
if len(serverNameParts) > 2 {
return "", "", 0, "", &InvalidServerName
}
if len(serverNameParts) == 2 {
instance = serverNameParts[1]
serverName = serverNameParts[0]
}
}
}
return serverName, instance, port, nil
return serverName, instance, port, protocol, err
}

// padRight appends c instances of s to builder
Expand Down
2 changes: 1 addition & 1 deletion pkg/sqlcmd/variables.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func (v Variables) SQLCmdUser() string {
}

// SQLCmdServer returns the server connection parameters derived from the SQLCMDSERVER variable value
func (v Variables) SQLCmdServer() (serverName string, instance string, port uint64, err error) {
func (v Variables) SQLCmdServer() (serverName string, instance string, port uint64, protocol string, err error) {
serverName = v[SQLCMDSERVER]
return splitServer(serverName)
}
Expand Down
6 changes: 4 additions & 2 deletions pkg/sqlcmd/variables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,22 @@ func TestSqlServerSplitsName(t *testing.T) {
vars := Variables{
SQLCMDSERVER: `tcp:someserver\someinstance`,
}
serverName, instance, port, err := vars.SQLCmdServer()
serverName, instance, port, protocol, err := vars.SQLCmdServer()
if assert.NoError(t, err, "tcp:server\\someinstance") {
assert.Equal(t, "someserver", serverName, "server name for instance")
assert.Equal(t, uint64(0), port, "port for instance")
assert.Equal(t, "someinstance", instance, "instance for instance")
assert.Equal(t, "tcp", protocol, "protocol for instance")
}
vars = Variables{
SQLCMDSERVER: `tcp:someserver,1111`,
}
serverName, instance, port, err = vars.SQLCmdServer()
serverName, instance, port, protocol, err = vars.SQLCmdServer()
if assert.NoError(t, err, "tcp:server,1111") {
assert.Equal(t, "someserver", serverName, "server name for port number")
assert.Equal(t, uint64(1111), port, "port for port number")
assert.Equal(t, "", instance, "instance for port number")
assert.Equal(t, "tcp", protocol, "protocol for port number")
}
}

Expand Down