From 1854c5cb2c118e0bbb9c06dc9b89e7ac80ed7601 Mon Sep 17 00:00:00 2001 From: chandan jain Date: Mon, 29 Nov 2021 20:53:33 +0530 Subject: [PATCH 01/21] add krb auth support --- go.mod | 7 + go.sum | 14 ++ kerbauth.go | 180 ++++++++++++++++++++++++ kerbauth_test.go | 301 +++++++++++++++++++++++++++++++++++++++++ msdsn/conn_str.go | 76 ++++++++++- msdsn/conn_str_test.go | 38 ++++++ tds.go | 17 ++- 7 files changed, 628 insertions(+), 5 deletions(-) create mode 100644 kerbauth.go create mode 100644 kerbauth_test.go diff --git a/go.mod b/go.mod index ebc02ab8..cc3029b8 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,12 @@ go 1.11 require ( github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe + github.com/hashicorp/go-uuid v1.0.2 // indirect + github.com/jcmturner/gofork v1.0.0 // indirect golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c + gopkg.in/jcmturner/aescts.v1 v1.0.1 // indirect + gopkg.in/jcmturner/dnsutils.v1 v1.0.1 // indirect + gopkg.in/jcmturner/goidentity.v3 v3.0.0 // indirect + gopkg.in/jcmturner/gokrb5.v7 v7.5.0 + gopkg.in/jcmturner/rpc.v1 v1.1.0 // indirect ) diff --git a/go.sum b/go.sum index 1887801b..433b68f8 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,19 @@ github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE= +github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/jcmturner/gofork v1.0.0 h1:J7uCkflzTEhUZ64xqKnkDxq3kzc96ajM1Gli5ktUem8= +github.com/jcmturner/gofork v1.0.0/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +gopkg.in/jcmturner/aescts.v1 v1.0.1 h1:cVVZBK2b1zY26haWB4vbBiZrfFQnfbTVrE3xZq6hrEw= +gopkg.in/jcmturner/aescts.v1 v1.0.1/go.mod h1:nsR8qBOg+OucoIW+WMhB3GspUQXq9XorLnQb9XtvcOo= +gopkg.in/jcmturner/dnsutils.v1 v1.0.1 h1:cIuC1OLRGZrld+16ZJvvZxVJeKPsvd5eUIvxfoN5hSM= +gopkg.in/jcmturner/dnsutils.v1 v1.0.1/go.mod h1:m3v+5svpVOhtFAP/wSz+yzh4Mc0Fg7eRhxkJMWSIz9Q= +gopkg.in/jcmturner/goidentity.v3 v3.0.0 h1:1duIyWiTaYvVx3YX2CYtpJbUFd7/UuPYCfgXtQ3VTbI= +gopkg.in/jcmturner/goidentity.v3 v3.0.0/go.mod h1:oG2kH0IvSYNIu80dVAyu/yoefjq1mNfM5bm88whjWx4= +gopkg.in/jcmturner/gokrb5.v7 v7.5.0 h1:a9tsXlIDD9SKxotJMK3niV7rPZAJeX2aD/0yg3qlIrg= +gopkg.in/jcmturner/gokrb5.v7 v7.5.0/go.mod h1:l8VISx+WGYp+Fp7KRbsiUuXTTOnxIc3Tuvyavf11/WM= +gopkg.in/jcmturner/rpc.v1 v1.1.0 h1:QHIUxTX1ISuAv9dD2wJ9HWQVuWDX/Zc0PfeC2tjc4rU= +gopkg.in/jcmturner/rpc.v1 v1.1.0/go.mod h1:YIdkC4XfD6GXbzje11McwsDuOlZQSb9W4vfLvuNnlv8= diff --git a/kerbauth.go b/kerbauth.go new file mode 100644 index 00000000..103ba611 --- /dev/null +++ b/kerbauth.go @@ -0,0 +1,180 @@ +package mssql + +import ( + "fmt" + "io/ioutil" + "log" + "os" + "strings" + + "gopkg.in/jcmturner/gokrb5.v7/client" + "gopkg.in/jcmturner/gokrb5.v7/config" + "gopkg.in/jcmturner/gokrb5.v7/credentials" + "gopkg.in/jcmturner/gokrb5.v7/keytab" + "gopkg.in/jcmturner/gokrb5.v7/messages" + "gopkg.in/jcmturner/gokrb5.v7/spnego" + "gopkg.in/jcmturner/gokrb5.v7/types" +) + +type Krb5ClientState int + +const ( + ContextFlagREADY = 128 + /* initiator states */ + InitiatorStart Krb5ClientState = iota + InitiatorRestart + InitiatorWaitForMutal + InitiatorReady +) + +type krb5Auth struct { + username string + realm string + service string + password string + port string + krb5ConfFile string + krbFile string + initkrbwithkeytab string + krb5Client *client.Client + state Krb5ClientState +} + +var clientWithKeytab = client.NewClientWithKeytab +var loadCCache = credentials.LoadCCache +var clientFromCCache = client.NewClientFromCCache +var spnegoNewNegToken = spnego.NewNegTokenInitKRB5 +var spnegoToken spnego.SPNEGOToken +var spnegoUnmarshal = spnegoToken.Unmarshal +var kt = &keytab.Keytab{} +var ktUnmarshal = kt.Unmarshal + +var negTokenMarshal = func(negTok spnego.NegTokenInit) ([]byte, error) { + return negTok.Marshal() +} +var getServiceTicket = func(cl *client.Client, spn string) (messages.Ticket, types.EncryptionKey, error) { + return cl.GetServiceTicket(spn) +} + +func getKRB5Auth(user, service, krb5Conf, krbFile, initkrbwithkeytab, password string) (auth, bool) { + if krb5Conf == "" { + krb5Conf = "/etc/krb5.conf" + } + var port string + var realm string + var serviceStr string + + params1 := strings.Split(service, ":") + if len(params1) != 2 { + return nil, false + } + + params2 := strings.Split(params1[1], "@") + if len(params2) == 1 { + port = params1[1] + } else if len(params2) == 2 { + port = params2[0] + } else if len(params2) != 2 { + return nil, false + } + + params3 := strings.Split(service, "@") + if len(params3) == 1 { + serviceStr = params3[0] + params3 = strings.Split(params1[0], "/") + params3 = strings.Split(params3[1], ".") + realm = params3[1] + "." + params3[2] + } else if len(params3) == 2 { + realm = params3[1] + serviceStr = params3[0] + } + + return &krb5Auth{ + username: user, + service: serviceStr, + port: port, + realm: realm, + krb5ConfFile: krb5Conf, + krbFile: krbFile, + password: password, + initkrbwithkeytab: initkrbwithkeytab, + }, true + +} + +func (auth *krb5Auth) InitialBytes() ([]byte, error) { + + krb5CnfFile, _ := os.Open(auth.krb5ConfFile) + c, _ := config.NewConfigFromReader(krb5CnfFile) + + // Set to lookup KDCs in DNS + c.LibDefaults.DNSLookupKDC = false + + var err error + var cl *client.Client + // Init keytab from conf + if auth.initkrbwithkeytab == "true" { + + keytabConf, err := ioutil.ReadFile(auth.krbFile) + if err != nil { + return []byte{}, err + } + if err = ktUnmarshal([]byte(keytabConf)); err != nil { + log.Printf("unmarshal keytabConf failed: %v", err) + return []byte{}, err + } + + // Init krb5 client and login + cl = clientWithKeytab(auth.username, auth.realm, kt, c, client.DisablePAFXFAST(true)) + + } else { + cache, err := loadCCache(auth.krbFile) + if err != nil { + log.Println(err) + return []byte{}, err + } + + cl, err = clientFromCCache(cache, c) + if err != nil { + log.Println(err) + return []byte{}, err + } + } + + auth.krb5Client = cl + auth.state = InitiatorStart + + tkt, sessionKey, err := getServiceTicket(cl, auth.service) + if err != nil { + return []byte{}, err + } + + negTok, err := spnegoNewNegToken(auth.krb5Client, tkt, sessionKey) + if err != nil { + fmt.Println(err) + return []byte{}, err + } + + outToken, err := negTokenMarshal(negTok) + if err != nil { + fmt.Println(err) + return []byte{}, err + } + auth.state = InitiatorWaitForMutal + return outToken, nil +} + +func (auth *krb5Auth) Free() { + auth.krb5Client.Destroy() +} + +func (auth *krb5Auth) NextBytes(token []byte) ([]byte, error) { + + if err := spnegoUnmarshal(token); err != nil { + err := fmt.Errorf("unmarshal APRep token failed: %w", err) + return []byte{}, err + } + + auth.state = InitiatorReady + return []byte{}, nil +} diff --git a/kerbauth_test.go b/kerbauth_test.go new file mode 100644 index 00000000..7211a031 --- /dev/null +++ b/kerbauth_test.go @@ -0,0 +1,301 @@ +package mssql + +import ( + "errors" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "testing" + + "gopkg.in/jcmturner/gokrb5.v7/client" + "gopkg.in/jcmturner/gokrb5.v7/config" + "gopkg.in/jcmturner/gokrb5.v7/credentials" + "gopkg.in/jcmturner/gokrb5.v7/keytab" + "gopkg.in/jcmturner/gokrb5.v7/messages" + "gopkg.in/jcmturner/gokrb5.v7/spnego" + "gopkg.in/jcmturner/gokrb5.v7/types" +) + +func createKrbFile(filename string) string { + ans := []byte{84, 104, 105, 115, 32, 105, 115, 32, 97, 32, 116, 101, 115, 116, 32, 102, 105, 108, 101, 46} + err := ioutil.WriteFile(filename, ans, 0644) + if err != nil { + fmt.Println("Could not write file") + } + + filedirectory := filepath.Dir(filename) + thepath, _ := filepath.Abs(filedirectory) + filePath := thepath + "/" + filename + + return filePath +} + +func deleteFile(filename string) { + if _, err := os.Stat(filename); err == nil { + err = os.Remove(filename) + if err != nil { + fmt.Println("Could not delete file") + } + } +} + +func TestGetKRB5Auth(t *testing.T) { + keytabFile := createKrbFile("admin.keytab") + got, _ := getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433", "/etc/krb5.conf", keytabFile, "true", "") + var keytab auth = &krb5Auth{username: "", + realm: "domain.com", + service: "MSSQLSvc/mssql.domain.com:1433", + password: "", + port: "1433", + krb5ConfFile: "/etc/krb5.conf", + krbFile: keytabFile, + initkrbwithkeytab: "true", + state: 0} + + res := reflect.DeepEqual(got, keytab) + + if !res { + t.Errorf("Failed to get correct krb5Auth object\nExpected:%v\nRecieved:%v", keytab, got) + } + + krbcacheFile := createKrbFile("krb5ccache_1000") + got, _ = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433", "", krbcacheFile, "true", "") + keytab = &krb5Auth{username: "", + realm: "domain.com", + service: "MSSQLSvc/mssql.domain.com:1433", + password: "", + port: "1433", + krb5ConfFile: "/etc/krb5.conf", + krbFile: krbcacheFile, + initkrbwithkeytab: "true", + state: 0} + + res = reflect.DeepEqual(got, keytab) + + if !res { + t.Errorf("Failed to get correct krb5Auth object\nExpected:%v\nRecieved:%v", keytab, got) + } + + _, val := getKRB5Auth("", "MSSQLSvc/mssql.domain.com", "", keytabFile, "true", "") + + if val { + t.Errorf("Failed to get correct krb5Auth object") + } + + got, _ = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433@DOMAIN.COM", "", keytabFile, "true", "") + keytab = &krb5Auth{username: "", + realm: "DOMAIN.COM", + service: "MSSQLSvc/mssql.domain.com:1433", + password: "", + port: "1433", + krb5ConfFile: "/etc/krb5.conf", + krbFile: keytabFile, + initkrbwithkeytab: "true", + state: 0} + + res = reflect.DeepEqual(got, keytab) + + if !res { + t.Errorf("Failed to get correct krb5Auth object\nExpected:%v\nRecieved:%v", keytab, got) + } + + _, val = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433@domain.com@test", "", keytabFile, "true", "") + + if val { + t.Errorf("Failed to get correct krb5Auth object due to incorrect service name") + } + + deleteFile(krbcacheFile) + deleteFile(keytabFile) + +} + +func TestInitialBytes(t *testing.T) { + + krbcacheFile := createKrbFile("krbcache_1000") + krbObj := &krb5Auth{username: "", + realm: "domain.com", + service: "MSSQLSvc/mssql.domain.com:1433", + password: "", + port: "1433", + krb5ConfFile: "/etc/krb5.conf", + krbFile: krbcacheFile, + initkrbwithkeytab: "", + state: 0, + } + + loadCCache = func(cpath string) (*credentials.CCache, error) { + return &credentials.CCache{}, nil + } + + clientFromCCache = func(c *credentials.CCache, krb5conf *config.Config, settings ...func(*client.Settings)) (*client.Client, error) { + return &client.Client{}, nil + } + + getServiceTicket = func(cl *client.Client, spn string) (messages.Ticket, types.EncryptionKey, error) { + return messages.Ticket{}, types.EncryptionKey{}, nil + } + spnegoNewNegToken = func(cl *client.Client, tkt messages.Ticket, sessionKey types.EncryptionKey) (spnego.NegTokenInit, error) { + return spnego.NegTokenInit{}, nil + } + + _, err := krbObj.InitialBytes() + if err != nil { + t.Errorf(err.Error()) + } + + loadCCache = func(cpath string) (*credentials.CCache, error) { + return &credentials.CCache{}, errors.New("Error loading cache file") + } + + _, err = krbObj.InitialBytes() + if err == nil { + t.Errorf(err.Error()) + } + + loadCCache = func(cpath string) (*credentials.CCache, error) { + return &credentials.CCache{}, nil + } + clientFromCCache = func(c *credentials.CCache, krb5conf *config.Config, settings ...func(*client.Settings)) (*client.Client, error) { + return &client.Client{}, errors.New("Failed to create a client from CCache") + } + _, err = krbObj.InitialBytes() + if err == nil { + t.Errorf(err.Error()) + } + + clientFromCCache = func(c *credentials.CCache, krb5conf *config.Config, settings ...func(*client.Settings)) (*client.Client, error) { + return &client.Client{}, nil + } + getServiceTicket = func(cl *client.Client, spn string) (messages.Ticket, types.EncryptionKey, error) { + return messages.Ticket{}, types.EncryptionKey{}, errors.New("Failed to create service ticket") + } + + _, err = krbObj.InitialBytes() + if err == nil { + t.Errorf(err.Error()) + } + + keytabFile := createKrbFile("admin.keytab") + krbObj.initkrbwithkeytab = "true" + krbObj.krbFile = keytabFile + + getServiceTicket = func(cl *client.Client, spn string) (messages.Ticket, types.EncryptionKey, error) { + return messages.Ticket{}, types.EncryptionKey{}, nil + } + ktUnmarshal = func(b []byte) error { + return nil + } + + clientWithKeytab = func(username string, realm string, kt *keytab.Keytab, krb5conf *config.Config, settings ...func(*client.Settings)) *client.Client { + return &client.Client{} + } + + _, err = krbObj.InitialBytes() + if err != nil { + t.Errorf(err.Error()) + } + + krbObj.krbFile = "Test" + + _, err = krbObj.InitialBytes() + if err == nil { + t.Errorf(err.Error()) + } + + krbObj.krbFile = keytabFile + ktUnmarshal = func(b []byte) error { + return errors.New("Failed to unmarshal keytab file") + } + + _, err = krbObj.InitialBytes() + if err == nil { + t.Errorf(err.Error()) + } + + ktUnmarshal = func(b []byte) error { + return nil + } + + spnegoNewNegToken = func(cl *client.Client, tkt messages.Ticket, sessionKey types.EncryptionKey) (spnego.NegTokenInit, error) { + return spnego.NegTokenInit{}, errors.New("Failed to create a new spnego token") + } + + _, err = krbObj.InitialBytes() + if err == nil { + t.Errorf(err.Error()) + } + + spnegoNewNegToken = func(cl *client.Client, tkt messages.Ticket, sessionKey types.EncryptionKey) (spnego.NegTokenInit, error) { + return spnego.NegTokenInit{}, nil + } + + negTokenMarshal = func(negTok spnego.NegTokenInit) ([]byte, error) { + return []byte{}, errors.New("Failed to marshal neg token") + } + + _, err = krbObj.InitialBytes() + if err == nil { + t.Errorf(err.Error()) + } + + deleteFile(krbcacheFile) + deleteFile(keytabFile) +} + +func TestNextBytes(t *testing.T) { + ans := []byte{84, 104, 105, 115, 32, 105, 115, 32, 97, 32, 116, 101, 115, 116, 32, 102, 105, 108, 101, 46} + spnegoUnmarshal = func(b []byte) error { + return nil + } + + keytabFile := createKrbFile("admin.keytab") + var krbObj auth = &krb5Auth{username: "", + realm: "domain.com", + service: "MSSQLSvc/mssql.domain.com:1433", + password: "", + port: "1433", + krb5ConfFile: "/etc/krb5.conf", + krbFile: keytabFile, + initkrbwithkeytab: "true", + state: 0} + + _, err := krbObj.NextBytes(ans) + if err != nil { + t.Errorf("Error getting next byte") + } + + spnegoUnmarshal = func(b []byte) error { + return errors.New("Failed to unmarshal") + } + + _, err = krbObj.NextBytes(ans) + if err == nil { + t.Errorf("Should fail to unmarshal but passed") + } + + deleteFile(keytabFile) +} + +func TestFree(t *testing.T) { + keytabFile := createKrbFile("admin.keytab") + kt := &keytab.Keytab{} + c := &config.Config{} + cl := client.NewClientWithKeytab("Administrator", "DOMAIN.COM", kt, c, client.DisablePAFXFAST(true)) + var krbObj auth = &krb5Auth{username: "", + realm: "domain.com", + service: "MSSQLSvc/mssql.domain.com:1433", + password: "", + port: "1433", + krb5ConfFile: "/etc/krb5.conf", + krbFile: keytabFile, + initkrbwithkeytab: "true", + state: 0, + krb5Client: cl, + } + + krbObj.Free() + deleteFile(keytabFile) +} diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index 4804036a..03b3cc0a 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -14,6 +14,8 @@ import ( "unicode" ) +const defaultServerPort = 1433 + type ( Encryption int Log uint64 @@ -72,6 +74,15 @@ type Config struct { ConnTimeout time.Duration // Use context for timeouts. KeepAlive time.Duration // Leave at default. PacketSize uint16 + + // Kerberos authentication fields + + Krb5ConfFile string + KrbCache string + Realm string + Initkrbwithkeytab string + Keytabfile string + EnableKerberos string } func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate string) (*tls.Config, error) { @@ -173,6 +184,40 @@ func Parse(dsn string) (Config, map[string]string, error) { } } + p.EnableKerberos = params["enablekerberos"] + if p.EnableKerberos == "true" { + missingParam := checkMissingKRBConfig(params) + if missingParam != "" { + return p, params, fmt.Errorf(" %s cannot be empty", missingParam) + } + + realm, ok := params["realm"] + if ok { + p.Realm = realm + } + + krbCache, ok := params["krbcache"] + if ok { + p.KrbCache = krbCache + } + + krb5ConfFile, ok := params["krb5conffile"] + if ok { + p.Krb5ConfFile = krb5ConfFile + } + + initkrbwithkeytab, ok := params["initkrbwithkeytab"] + if ok { + p.Initkrbwithkeytab = initkrbwithkeytab + } + + keytabfile, ok := params["keytabfile"] + if ok { + p.Keytabfile = keytabfile + } + + } + // https://msdn.microsoft.com/en-us/library/dd341108.aspx // // Do not set a connection timeout. Use Context to manage such things. @@ -259,7 +304,7 @@ func Parse(dsn string) (Config, map[string]string, error) { if ok { p.ServerSPN = serverSPN } else { - p.ServerSPN = generateSpn(p.Host, p.Port) + p.ServerSPN = generateSpn(p.Host, resolveServerPort(p.Port), p.Realm) } workstation, ok := params["workstation id"] @@ -318,6 +363,20 @@ func Parse(dsn string) (Config, map[string]string, error) { return p, params, nil } +func checkMissingKRBConfig(c map[string]string) (missingParam string) { + if c["initkrbwithkeytab"] == "true" { + if c["keytabfile"] == "" { + missingParam = "keytabfile" + } + if c["realm"] == "" { + missingParam = "realm" + } + } else if c["krbcache"] == "" { + missingParam = "krbcache" + } + return +} + // convert connectionParams to url style connection string // used mostly for testing func (p Config) URL() *url.URL { @@ -597,6 +656,17 @@ func normalizeOdbcKey(s string) string { return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace)) } -func generateSpn(host string, port uint64) string { - return fmt.Sprintf("MSSQLSvc/%s:%d", host, port) +func resolveServerPort(port uint64) uint64 { + if port == 0 { + return defaultServerPort + } + + return port +} + +func generateSpn(host string, port uint64, realm string) string { + if realm == "" { + return fmt.Sprintf("MSSQLSvc/%s:%d", host, port) + } + return fmt.Sprintf("MSSQLSvc/%s:%d@%s", host, port, realm) } diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index 594b5b3d..c8c191e8 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -196,3 +196,41 @@ func TestConnParseRoundTripFixed(t *testing.T) { t.Fatal("Parameters do not match after roundtrip", params, rtParams) } } + +func TestInvalidConnectionStringKerberos(t *testing.T) { + + connStrings := []string{ + "server=server;port=1345;realm=domain;trustservercertificate=true;keytabfile=/path/to/administrator2.keytab.keytab;enablekerberos=true", + "server=server;port=1345;realm=domain;trustservercertificate=true;krbcache=;enablekerberos=true", + "server=server;user id=user;password=pwd;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;enablekerberos=true", + "server=server;user id=user;password=pwd;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;keytabfile=/path/to/administrator2.keytab;enablekerberos=true", + "server=server;user id=user;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;keytabfile=/path/to/administrator2.keytab;enablekerberos=true;initkrbwithkeytab=false", + "server=server;user id=user;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;enablekerberos=true;initkrbwithkeytab=true", + } + for _, connStr := range connStrings { + _, _, err := Parse(connStr) + if err == nil { + t.Errorf("Connection expected to fail for connection string %s but it didn't", connStr) + continue + } else { + t.Logf("Connection failed for %s as expected with error %v", connStr, err) + } + } +} + +func TestValidConnectionStringKerberos(t *testing.T) { + connStrings := []string{ + "server=server;user id=user;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;keytabfile=/home/user/Pictures/admin.keytab;enablekerberos=true;initkrbwithkeytab=true", + "server=server;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;krbcache=/tmp/krb5cc_1000;enablekerberos=true", + } + + for _, connStr := range connStrings { + _, _, err := Parse(connStr) + if err == nil { + t.Logf("Connection string was parsed successfully %s", connStrings) + } else { + t.Errorf("Connection string %s failed to parse with error %s", connStrings, err) + } + } + +} diff --git a/tds.go b/tds.go index dbc9b211..def37fb5 100644 --- a/tds.go +++ b/tds.go @@ -1011,6 +1011,10 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont if err != nil { return nil, err } + _, ok := auth.(*krb5Auth) + if ok { + l.UserName = p.User + } l.OptionFlags2 |= fIntSecurity return l, nil @@ -1157,8 +1161,17 @@ initiate_connection: } } } - - auth, authOk := getAuth(p.User, p.Password, p.ServerSPN, p.Workstation) + var auth auth + var authOk bool + if p.EnableKerberos == "true" { + if p.Initkrbwithkeytab == "true" { + auth, authOk = getKRB5Auth(p.User, p.ServerSPN, p.Krb5ConfFile, p.Keytabfile, p.Initkrbwithkeytab, p.Password) + } else { + auth, authOk = getKRB5Auth(p.User, p.ServerSPN, p.Krb5ConfFile, p.KrbCache, p.Initkrbwithkeytab, p.Password) + } + } else { + auth, authOk = getAuth(p.User, p.Password, p.ServerSPN, p.Workstation) + } if authOk { defer auth.Free() } else { From f2e846d7c214aac410fb10b844f4f11ed7aa83a6 Mon Sep 17 00:00:00 2001 From: chandanjainn Date: Tue, 30 Nov 2021 19:39:46 +0530 Subject: [PATCH 02/21] worked on review comments --- kerbauth.go | 64 ++++++++++++++++++++----------------- kerbauth_test.go | 71 +++++++++++++++++++++--------------------- msdsn/conn_str.go | 56 ++++++++++++++++++++++++--------- msdsn/conn_str_test.go | 4 +-- tds.go | 13 +++----- 5 files changed, 119 insertions(+), 89 deletions(-) diff --git a/kerbauth.go b/kerbauth.go index 103ba611..bcee8829 100644 --- a/kerbauth.go +++ b/kerbauth.go @@ -3,8 +3,8 @@ package mssql import ( "fmt" "io/ioutil" - "log" "os" + "strconv" "strings" "gopkg.in/jcmturner/gokrb5.v7/client" @@ -19,7 +19,6 @@ import ( type Krb5ClientState int const ( - ContextFlagREADY = 128 /* initiator states */ InitiatorStart Krb5ClientState = iota InitiatorRestart @@ -32,10 +31,10 @@ type krb5Auth struct { realm string service string password string - port string + port uint64 krb5ConfFile string krbFile string - initkrbwithkeytab string + initkrbwithkeytab bool krb5Client *client.Client state Krb5ClientState } @@ -56,13 +55,11 @@ var getServiceTicket = func(cl *client.Client, spn string) (messages.Ticket, typ return cl.GetServiceTicket(spn) } -func getKRB5Auth(user, service, krb5Conf, krbFile, initkrbwithkeytab, password string) (auth, bool) { - if krb5Conf == "" { - krb5Conf = "/etc/krb5.conf" - } - var port string +func getKRB5Auth(user, service, krb5Conf, krbFile, password string, initkrbwithkeytab bool) (auth, bool) { + var port uint64 var realm string var serviceStr string + var err error params1 := strings.Split(service, ":") if len(params1) != 2 { @@ -70,23 +67,36 @@ func getKRB5Auth(user, service, krb5Conf, krbFile, initkrbwithkeytab, password s } params2 := strings.Split(params1[1], "@") - if len(params2) == 1 { - port = params1[1] - } else if len(params2) == 2 { - port = params2[0] - } else if len(params2) != 2 { + switch len(params2) { + case 1: + port, err = strconv.ParseUint(params1[1], 10, 16) + if err != nil { + return nil, false + } + + case 2: + port, err = strconv.ParseUint(params2[0], 10, 16) + if err != nil { + return nil, false + } + default: return nil, false } params3 := strings.Split(service, "@") - if len(params3) == 1 { + switch len(params3) { + case 1: serviceStr = params3[0] params3 = strings.Split(params1[0], "/") params3 = strings.Split(params3[1], ".") realm = params3[1] + "." + params3[2] - } else if len(params3) == 2 { + + case 2: realm = params3[1] serviceStr = params3[0] + + default: + return nil, false } return &krb5Auth{ @@ -103,24 +113,28 @@ func getKRB5Auth(user, service, krb5Conf, krbFile, initkrbwithkeytab, password s } func (auth *krb5Auth) InitialBytes() ([]byte, error) { - - krb5CnfFile, _ := os.Open(auth.krb5ConfFile) - c, _ := config.NewConfigFromReader(krb5CnfFile) + var err error + krb5CnfFile, err := os.Open(auth.krb5ConfFile) + if err != nil { + return []byte{}, err + } + c, err := config.NewConfigFromReader(krb5CnfFile) + if err != nil { + return []byte{}, err + } // Set to lookup KDCs in DNS c.LibDefaults.DNSLookupKDC = false - var err error var cl *client.Client // Init keytab from conf - if auth.initkrbwithkeytab == "true" { + if auth.initkrbwithkeytab { keytabConf, err := ioutil.ReadFile(auth.krbFile) if err != nil { return []byte{}, err } if err = ktUnmarshal([]byte(keytabConf)); err != nil { - log.Printf("unmarshal keytabConf failed: %v", err) return []byte{}, err } @@ -130,13 +144,11 @@ func (auth *krb5Auth) InitialBytes() ([]byte, error) { } else { cache, err := loadCCache(auth.krbFile) if err != nil { - log.Println(err) return []byte{}, err } cl, err = clientFromCCache(cache, c) if err != nil { - log.Println(err) return []byte{}, err } } @@ -151,13 +163,11 @@ func (auth *krb5Auth) InitialBytes() ([]byte, error) { negTok, err := spnegoNewNegToken(auth.krb5Client, tkt, sessionKey) if err != nil { - fmt.Println(err) return []byte{}, err } outToken, err := negTokenMarshal(negTok) if err != nil { - fmt.Println(err) return []byte{}, err } auth.state = InitiatorWaitForMutal @@ -169,12 +179,10 @@ func (auth *krb5Auth) Free() { } func (auth *krb5Auth) NextBytes(token []byte) ([]byte, error) { - if err := spnegoUnmarshal(token); err != nil { err := fmt.Errorf("unmarshal APRep token failed: %w", err) return []byte{}, err } - auth.state = InitiatorReady return []byte{}, nil } diff --git a/kerbauth_test.go b/kerbauth_test.go index 7211a031..bb21a451 100644 --- a/kerbauth_test.go +++ b/kerbauth_test.go @@ -2,7 +2,6 @@ package mssql import ( "errors" - "fmt" "io/ioutil" "os" "path/filepath" @@ -18,13 +17,13 @@ import ( "gopkg.in/jcmturner/gokrb5.v7/types" ) -func createKrbFile(filename string) string { +func createKrbFile(filename string, t *testing.T) string { + //The byte array is used to create a basic file for testing purpose ans := []byte{84, 104, 105, 115, 32, 105, 115, 32, 97, 32, 116, 101, 115, 116, 32, 102, 105, 108, 101, 46} err := ioutil.WriteFile(filename, ans, 0644) if err != nil { - fmt.Println("Could not write file") + t.Errorf("Could not write file") } - filedirectory := filepath.Dir(filename) thepath, _ := filepath.Abs(filedirectory) filePath := thepath + "/" + filename @@ -32,26 +31,26 @@ func createKrbFile(filename string) string { return filePath } -func deleteFile(filename string) { +func deleteFile(filename string, t *testing.T) { if _, err := os.Stat(filename); err == nil { err = os.Remove(filename) if err != nil { - fmt.Println("Could not delete file") + t.Errorf("Could not delete file: %v", filename) } } } func TestGetKRB5Auth(t *testing.T) { - keytabFile := createKrbFile("admin.keytab") - got, _ := getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433", "/etc/krb5.conf", keytabFile, "true", "") + keytabFile := createKrbFile("admin.keytab", t) + got, _ := getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433", "/etc/krb5.conf", keytabFile, "", true) var keytab auth = &krb5Auth{username: "", realm: "domain.com", service: "MSSQLSvc/mssql.domain.com:1433", password: "", - port: "1433", + port: 1433, krb5ConfFile: "/etc/krb5.conf", krbFile: keytabFile, - initkrbwithkeytab: "true", + initkrbwithkeytab: true, state: 0} res := reflect.DeepEqual(got, keytab) @@ -60,16 +59,16 @@ func TestGetKRB5Auth(t *testing.T) { t.Errorf("Failed to get correct krb5Auth object\nExpected:%v\nRecieved:%v", keytab, got) } - krbcacheFile := createKrbFile("krb5ccache_1000") - got, _ = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433", "", krbcacheFile, "true", "") + krbcacheFile := createKrbFile("krb5ccache_1000", t) + got, _ = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433", "/etc/krb5.conf", krbcacheFile, "", true) keytab = &krb5Auth{username: "", realm: "domain.com", service: "MSSQLSvc/mssql.domain.com:1433", password: "", - port: "1433", + port: 1433, krb5ConfFile: "/etc/krb5.conf", krbFile: krbcacheFile, - initkrbwithkeytab: "true", + initkrbwithkeytab: true, state: 0} res = reflect.DeepEqual(got, keytab) @@ -78,21 +77,21 @@ func TestGetKRB5Auth(t *testing.T) { t.Errorf("Failed to get correct krb5Auth object\nExpected:%v\nRecieved:%v", keytab, got) } - _, val := getKRB5Auth("", "MSSQLSvc/mssql.domain.com", "", keytabFile, "true", "") + _, val := getKRB5Auth("", "MSSQLSvc/mssql.domain.com", "/etc/krb5.conf", keytabFile, "", true) if val { t.Errorf("Failed to get correct krb5Auth object") } - got, _ = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433@DOMAIN.COM", "", keytabFile, "true", "") + got, _ = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433@DOMAIN.COM", "/etc/krb5.conf", keytabFile, "", true) keytab = &krb5Auth{username: "", realm: "DOMAIN.COM", service: "MSSQLSvc/mssql.domain.com:1433", password: "", - port: "1433", + port: 1433, krb5ConfFile: "/etc/krb5.conf", krbFile: keytabFile, - initkrbwithkeytab: "true", + initkrbwithkeytab: true, state: 0} res = reflect.DeepEqual(got, keytab) @@ -101,28 +100,28 @@ func TestGetKRB5Auth(t *testing.T) { t.Errorf("Failed to get correct krb5Auth object\nExpected:%v\nRecieved:%v", keytab, got) } - _, val = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433@domain.com@test", "", keytabFile, "true", "") + _, val = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433@domain.com@test", "", keytabFile, "", true) if val { t.Errorf("Failed to get correct krb5Auth object due to incorrect service name") } - deleteFile(krbcacheFile) - deleteFile(keytabFile) + defer deleteFile(krbcacheFile, t) + defer deleteFile(keytabFile, t) } func TestInitialBytes(t *testing.T) { - krbcacheFile := createKrbFile("krbcache_1000") + krbcacheFile := createKrbFile("krbcache_1000", t) krbObj := &krb5Auth{username: "", realm: "domain.com", service: "MSSQLSvc/mssql.domain.com:1433", password: "", - port: "1433", + port: 1433, krb5ConfFile: "/etc/krb5.conf", krbFile: krbcacheFile, - initkrbwithkeytab: "", + initkrbwithkeytab: false, state: 0, } @@ -178,8 +177,8 @@ func TestInitialBytes(t *testing.T) { t.Errorf(err.Error()) } - keytabFile := createKrbFile("admin.keytab") - krbObj.initkrbwithkeytab = "true" + keytabFile := createKrbFile("admin.keytab", t) + krbObj.initkrbwithkeytab = true krbObj.krbFile = keytabFile getServiceTicket = func(cl *client.Client, spn string) (messages.Ticket, types.EncryptionKey, error) { @@ -241,8 +240,8 @@ func TestInitialBytes(t *testing.T) { t.Errorf(err.Error()) } - deleteFile(krbcacheFile) - deleteFile(keytabFile) + defer deleteFile(krbcacheFile, t) + defer deleteFile(keytabFile, t) } func TestNextBytes(t *testing.T) { @@ -251,15 +250,15 @@ func TestNextBytes(t *testing.T) { return nil } - keytabFile := createKrbFile("admin.keytab") + keytabFile := createKrbFile("admin.keytab", t) var krbObj auth = &krb5Auth{username: "", realm: "domain.com", service: "MSSQLSvc/mssql.domain.com:1433", password: "", - port: "1433", + port: 1433, krb5ConfFile: "/etc/krb5.conf", krbFile: keytabFile, - initkrbwithkeytab: "true", + initkrbwithkeytab: true, state: 0} _, err := krbObj.NextBytes(ans) @@ -276,11 +275,11 @@ func TestNextBytes(t *testing.T) { t.Errorf("Should fail to unmarshal but passed") } - deleteFile(keytabFile) + defer deleteFile(keytabFile, t) } func TestFree(t *testing.T) { - keytabFile := createKrbFile("admin.keytab") + keytabFile := createKrbFile("admin.keytab", t) kt := &keytab.Keytab{} c := &config.Config{} cl := client.NewClientWithKeytab("Administrator", "DOMAIN.COM", kt, c, client.DisablePAFXFAST(true)) @@ -288,14 +287,14 @@ func TestFree(t *testing.T) { realm: "domain.com", service: "MSSQLSvc/mssql.domain.com:1433", password: "", - port: "1433", + port: 1433, krb5ConfFile: "/etc/krb5.conf", krbFile: keytabFile, - initkrbwithkeytab: "true", + initkrbwithkeytab: true, state: 0, krb5Client: cl, } krbObj.Free() - deleteFile(keytabFile) + defer deleteFile(keytabFile, t) } diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index 03b3cc0a..bc963245 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -76,13 +76,24 @@ type Config struct { PacketSize uint16 // Kerberos authentication fields + // Path to krb5.conf file that contains Kerberos configuration information + Krb5ConfFilePath string - Krb5ConfFile string - KrbCache string - Realm string - Initkrbwithkeytab string - Keytabfile string - EnableKerberos string + // Credential cache path + KrbCachePath string + + // A Kerberos realm is the domain over which a Kerberos authentication server has the authority + // to authenticate a user, host or service. + KrbRealm string + + // Flag to authenticate using keytab file + Initkrbwithkeytab bool + + // Path to keytab file that stores long-term keys for one or more principals + KeytabFilePath string + + // Flag to enable kerberos authentication + EnableKerberos bool } func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate string) (*tls.Config, error) { @@ -184,8 +195,16 @@ func Parse(dsn string) (Config, map[string]string, error) { } } - p.EnableKerberos = params["enablekerberos"] - if p.EnableKerberos == "true" { + enablekerberos, ok := params["enablekerberos"] + if ok { + var err error + p.EnableKerberos, err = strconv.ParseBool(enablekerberos) + if err != nil { + f := "invalid enablekerberos flag '%v': %v" + return p, params, fmt.Errorf(f, enablekerberos, err.Error()) + } + } + if p.EnableKerberos { missingParam := checkMissingKRBConfig(params) if missingParam != "" { return p, params, fmt.Errorf(" %s cannot be empty", missingParam) @@ -193,27 +212,32 @@ func Parse(dsn string) (Config, map[string]string, error) { realm, ok := params["realm"] if ok { - p.Realm = realm + p.KrbRealm = realm } krbCache, ok := params["krbcache"] if ok { - p.KrbCache = krbCache + p.KrbCachePath = krbCache } krb5ConfFile, ok := params["krb5conffile"] if ok { - p.Krb5ConfFile = krb5ConfFile + p.Krb5ConfFilePath = krb5ConfFile } initkrbwithkeytab, ok := params["initkrbwithkeytab"] if ok { - p.Initkrbwithkeytab = initkrbwithkeytab + var err error + p.Initkrbwithkeytab, err = strconv.ParseBool(initkrbwithkeytab) + if err != nil { + f := "invalid initkrbwithkeytab flag '%v': %v" + return p, params, fmt.Errorf(f, initkrbwithkeytab, err.Error()) + } } keytabfile, ok := params["keytabfile"] if ok { - p.Keytabfile = keytabfile + p.KeytabFilePath = keytabfile } } @@ -304,7 +328,7 @@ func Parse(dsn string) (Config, map[string]string, error) { if ok { p.ServerSPN = serverSPN } else { - p.ServerSPN = generateSpn(p.Host, resolveServerPort(p.Port), p.Realm) + p.ServerSPN = generateSpn(p.Host, resolveServerPort(p.Port), p.KrbRealm) } workstation, ok := params["workstation id"] @@ -364,6 +388,10 @@ func Parse(dsn string) (Config, map[string]string, error) { } func checkMissingKRBConfig(c map[string]string) (missingParam string) { + if c["krb5conffile"] == "" { + missingParam = "krb5conffile" + return + } if c["initkrbwithkeytab"] == "true" { if c["keytabfile"] == "" { missingParam = "keytabfile" diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index c8c191e8..53ef6d29 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -200,7 +200,7 @@ func TestConnParseRoundTripFixed(t *testing.T) { func TestInvalidConnectionStringKerberos(t *testing.T) { connStrings := []string{ - "server=server;port=1345;realm=domain;trustservercertificate=true;keytabfile=/path/to/administrator2.keytab.keytab;enablekerberos=true", + "server=server;port=1345;realm=domain;trustservercertificate=true;keytabfile=/path/to/administrator2.keytab;enablekerberos=true", "server=server;port=1345;realm=domain;trustservercertificate=true;krbcache=;enablekerberos=true", "server=server;user id=user;password=pwd;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;enablekerberos=true", "server=server;user id=user;password=pwd;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;keytabfile=/path/to/administrator2.keytab;enablekerberos=true", @@ -220,7 +220,7 @@ func TestInvalidConnectionStringKerberos(t *testing.T) { func TestValidConnectionStringKerberos(t *testing.T) { connStrings := []string{ - "server=server;user id=user;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;keytabfile=/home/user/Pictures/admin.keytab;enablekerberos=true;initkrbwithkeytab=true", + "server=server;user id=user;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;keytabfile=/path/to/admin.keytab;enablekerberos=true;initkrbwithkeytab=true", "server=server;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;krbcache=/tmp/krb5cc_1000;enablekerberos=true", } diff --git a/tds.go b/tds.go index def37fb5..a453b852 100644 --- a/tds.go +++ b/tds.go @@ -1011,11 +1011,6 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont if err != nil { return nil, err } - _, ok := auth.(*krb5Auth) - if ok { - l.UserName = p.User - } - l.OptionFlags2 |= fIntSecurity return l, nil @@ -1163,11 +1158,11 @@ initiate_connection: } var auth auth var authOk bool - if p.EnableKerberos == "true" { - if p.Initkrbwithkeytab == "true" { - auth, authOk = getKRB5Auth(p.User, p.ServerSPN, p.Krb5ConfFile, p.Keytabfile, p.Initkrbwithkeytab, p.Password) + if p.EnableKerberos { + if p.Initkrbwithkeytab { + auth, authOk = getKRB5Auth(p.User, p.ServerSPN, p.Krb5ConfFilePath, p.KeytabFilePath, p.Password, p.Initkrbwithkeytab) } else { - auth, authOk = getKRB5Auth(p.User, p.ServerSPN, p.Krb5ConfFile, p.KrbCache, p.Initkrbwithkeytab, p.Password) + auth, authOk = getKRB5Auth(p.User, p.ServerSPN, p.Krb5ConfFilePath, p.KrbCachePath, p.Password, p.Initkrbwithkeytab) } } else { auth, authOk = getAuth(p.User, p.Password, p.ServerSPN, p.Workstation) From b46e98ebdb3058839c7fe5b72b328d2d30075c71 Mon Sep 17 00:00:00 2001 From: chandanjainn Date: Tue, 30 Nov 2021 20:40:58 +0530 Subject: [PATCH 03/21] Rename kerberos variables. --- kerbauth.go | 24 +++++++----------------- kerbauth_test.go | 14 +++++++------- types.go | 11 +++++++++++ 3 files changed, 25 insertions(+), 24 deletions(-) diff --git a/kerbauth.go b/kerbauth.go index bcee8829..e0046a7d 100644 --- a/kerbauth.go +++ b/kerbauth.go @@ -16,27 +16,17 @@ import ( "gopkg.in/jcmturner/gokrb5.v7/types" ) -type Krb5ClientState int - -const ( - /* initiator states */ - InitiatorStart Krb5ClientState = iota - InitiatorRestart - InitiatorWaitForMutal - InitiatorReady -) - type krb5Auth struct { username string realm string - service string + serverSPN string password string port uint64 krb5ConfFile string krbFile string initkrbwithkeytab bool krb5Client *client.Client - state Krb5ClientState + state krb5ClientState } var clientWithKeytab = client.NewClientWithKeytab @@ -55,13 +45,13 @@ var getServiceTicket = func(cl *client.Client, spn string) (messages.Ticket, typ return cl.GetServiceTicket(spn) } -func getKRB5Auth(user, service, krb5Conf, krbFile, password string, initkrbwithkeytab bool) (auth, bool) { +func getKRB5Auth(user, serverSPN, krb5Conf, krbFile, password string, initkrbwithkeytab bool) (auth, bool) { var port uint64 var realm string var serviceStr string var err error - params1 := strings.Split(service, ":") + params1 := strings.Split(serverSPN, ":") if len(params1) != 2 { return nil, false } @@ -83,7 +73,7 @@ func getKRB5Auth(user, service, krb5Conf, krbFile, password string, initkrbwithk return nil, false } - params3 := strings.Split(service, "@") + params3 := strings.Split(serverSPN, "@") switch len(params3) { case 1: serviceStr = params3[0] @@ -101,7 +91,7 @@ func getKRB5Auth(user, service, krb5Conf, krbFile, password string, initkrbwithk return &krb5Auth{ username: user, - service: serviceStr, + serverSPN: serviceStr, port: port, realm: realm, krb5ConfFile: krb5Conf, @@ -156,7 +146,7 @@ func (auth *krb5Auth) InitialBytes() ([]byte, error) { auth.krb5Client = cl auth.state = InitiatorStart - tkt, sessionKey, err := getServiceTicket(cl, auth.service) + tkt, sessionKey, err := getServiceTicket(cl, auth.serverSPN) if err != nil { return []byte{}, err } diff --git a/kerbauth_test.go b/kerbauth_test.go index bb21a451..0d46b9c9 100644 --- a/kerbauth_test.go +++ b/kerbauth_test.go @@ -45,7 +45,7 @@ func TestGetKRB5Auth(t *testing.T) { got, _ := getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433", "/etc/krb5.conf", keytabFile, "", true) var keytab auth = &krb5Auth{username: "", realm: "domain.com", - service: "MSSQLSvc/mssql.domain.com:1433", + serverSPN: "MSSQLSvc/mssql.domain.com:1433", password: "", port: 1433, krb5ConfFile: "/etc/krb5.conf", @@ -63,7 +63,7 @@ func TestGetKRB5Auth(t *testing.T) { got, _ = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433", "/etc/krb5.conf", krbcacheFile, "", true) keytab = &krb5Auth{username: "", realm: "domain.com", - service: "MSSQLSvc/mssql.domain.com:1433", + serverSPN: "MSSQLSvc/mssql.domain.com:1433", password: "", port: 1433, krb5ConfFile: "/etc/krb5.conf", @@ -86,7 +86,7 @@ func TestGetKRB5Auth(t *testing.T) { got, _ = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433@DOMAIN.COM", "/etc/krb5.conf", keytabFile, "", true) keytab = &krb5Auth{username: "", realm: "DOMAIN.COM", - service: "MSSQLSvc/mssql.domain.com:1433", + serverSPN: "MSSQLSvc/mssql.domain.com:1433", password: "", port: 1433, krb5ConfFile: "/etc/krb5.conf", @@ -103,7 +103,7 @@ func TestGetKRB5Auth(t *testing.T) { _, val = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433@domain.com@test", "", keytabFile, "", true) if val { - t.Errorf("Failed to get correct krb5Auth object due to incorrect service name") + t.Errorf("Failed to get correct krb5Auth object due to incorrect serverSPN name") } defer deleteFile(krbcacheFile, t) @@ -116,7 +116,7 @@ func TestInitialBytes(t *testing.T) { krbcacheFile := createKrbFile("krbcache_1000", t) krbObj := &krb5Auth{username: "", realm: "domain.com", - service: "MSSQLSvc/mssql.domain.com:1433", + serverSPN: "MSSQLSvc/mssql.domain.com:1433", password: "", port: 1433, krb5ConfFile: "/etc/krb5.conf", @@ -253,7 +253,7 @@ func TestNextBytes(t *testing.T) { keytabFile := createKrbFile("admin.keytab", t) var krbObj auth = &krb5Auth{username: "", realm: "domain.com", - service: "MSSQLSvc/mssql.domain.com:1433", + serverSPN: "MSSQLSvc/mssql.domain.com:1433", password: "", port: 1433, krb5ConfFile: "/etc/krb5.conf", @@ -285,7 +285,7 @@ func TestFree(t *testing.T) { cl := client.NewClientWithKeytab("Administrator", "DOMAIN.COM", kt, c, client.DisablePAFXFAST(true)) var krbObj auth = &krb5Auth{username: "", realm: "domain.com", - service: "MSSQLSvc/mssql.domain.com:1433", + serverSPN: "MSSQLSvc/mssql.domain.com:1433", password: "", port: 1433, krb5ConfFile: "/etc/krb5.conf", diff --git a/types.go b/types.go index 822a7d86..c5be2741 100644 --- a/types.go +++ b/types.go @@ -113,6 +113,17 @@ type xmlInfo struct { XmlSchemaCollection string } +// Kerberos Client State +type krb5ClientState int + +const ( + /* initiator states */ + InitiatorStart krb5ClientState = iota + InitiatorRestart + InitiatorWaitForMutal + InitiatorReady +) + func readTypeInfo(r *tdsBuffer) (res typeInfo) { res.TypeId = r.byte() switch res.TypeId { From 1f85924749fcc6e17a8750e194429f8a8b98b7c0 Mon Sep 17 00:00:00 2001 From: chandanjainn Date: Wed, 1 Dec 2021 17:22:05 +0530 Subject: [PATCH 04/21] Removed leading space --- kerbauth.go | 5 ----- kerbauth_test.go | 7 ------- 2 files changed, 12 deletions(-) diff --git a/kerbauth.go b/kerbauth.go index e0046a7d..6d078ff3 100644 --- a/kerbauth.go +++ b/kerbauth.go @@ -63,7 +63,6 @@ func getKRB5Auth(user, serverSPN, krb5Conf, krbFile, password string, initkrbwit if err != nil { return nil, false } - case 2: port, err = strconv.ParseUint(params2[0], 10, 16) if err != nil { @@ -80,11 +79,9 @@ func getKRB5Auth(user, serverSPN, krb5Conf, krbFile, password string, initkrbwit params3 = strings.Split(params1[0], "/") params3 = strings.Split(params3[1], ".") realm = params3[1] + "." + params3[2] - case 2: realm = params3[1] serviceStr = params3[0] - default: return nil, false } @@ -99,7 +96,6 @@ func getKRB5Auth(user, serverSPN, krb5Conf, krbFile, password string, initkrbwit password: password, initkrbwithkeytab: initkrbwithkeytab, }, true - } func (auth *krb5Auth) InitialBytes() ([]byte, error) { @@ -119,7 +115,6 @@ func (auth *krb5Auth) InitialBytes() ([]byte, error) { var cl *client.Client // Init keytab from conf if auth.initkrbwithkeytab { - keytabConf, err := ioutil.ReadFile(auth.krbFile) if err != nil { return []byte{}, err diff --git a/kerbauth_test.go b/kerbauth_test.go index 0d46b9c9..f195435a 100644 --- a/kerbauth_test.go +++ b/kerbauth_test.go @@ -72,13 +72,11 @@ func TestGetKRB5Auth(t *testing.T) { state: 0} res = reflect.DeepEqual(got, keytab) - if !res { t.Errorf("Failed to get correct krb5Auth object\nExpected:%v\nRecieved:%v", keytab, got) } _, val := getKRB5Auth("", "MSSQLSvc/mssql.domain.com", "/etc/krb5.conf", keytabFile, "", true) - if val { t.Errorf("Failed to get correct krb5Auth object") } @@ -95,24 +93,19 @@ func TestGetKRB5Auth(t *testing.T) { state: 0} res = reflect.DeepEqual(got, keytab) - if !res { t.Errorf("Failed to get correct krb5Auth object\nExpected:%v\nRecieved:%v", keytab, got) } _, val = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433@domain.com@test", "", keytabFile, "", true) - if val { t.Errorf("Failed to get correct krb5Auth object due to incorrect serverSPN name") } - defer deleteFile(krbcacheFile, t) defer deleteFile(keytabFile, t) - } func TestInitialBytes(t *testing.T) { - krbcacheFile := createKrbFile("krbcache_1000", t) krbObj := &krb5Auth{username: "", realm: "domain.com", From 2fe0bca5e43343eb8fc67941e3b7e3cfec8435b1 Mon Sep 17 00:00:00 2001 From: chandanjainn Date: Mon, 6 Dec 2021 12:54:24 +0530 Subject: [PATCH 05/21] unexport initiator states --- kerbauth.go | 7 +++---- types.go | 10 +++++----- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/kerbauth.go b/kerbauth.go index 6d078ff3..71996d73 100644 --- a/kerbauth.go +++ b/kerbauth.go @@ -99,7 +99,6 @@ func getKRB5Auth(user, serverSPN, krb5Conf, krbFile, password string, initkrbwit } func (auth *krb5Auth) InitialBytes() ([]byte, error) { - var err error krb5CnfFile, err := os.Open(auth.krb5ConfFile) if err != nil { return []byte{}, err @@ -139,7 +138,7 @@ func (auth *krb5Auth) InitialBytes() ([]byte, error) { } auth.krb5Client = cl - auth.state = InitiatorStart + auth.state = initiatorStart tkt, sessionKey, err := getServiceTicket(cl, auth.serverSPN) if err != nil { @@ -155,7 +154,7 @@ func (auth *krb5Auth) InitialBytes() ([]byte, error) { if err != nil { return []byte{}, err } - auth.state = InitiatorWaitForMutal + auth.state = initiatorWaitForMutal return outToken, nil } @@ -168,6 +167,6 @@ func (auth *krb5Auth) NextBytes(token []byte) ([]byte, error) { err := fmt.Errorf("unmarshal APRep token failed: %w", err) return []byte{}, err } - auth.state = InitiatorReady + auth.state = initiatorReady return []byte{}, nil } diff --git a/types.go b/types.go index c5be2741..155efffd 100644 --- a/types.go +++ b/types.go @@ -117,11 +117,11 @@ type xmlInfo struct { type krb5ClientState int const ( - /* initiator states */ - InitiatorStart krb5ClientState = iota - InitiatorRestart - InitiatorWaitForMutal - InitiatorReady + // Initiator states + initiatorStart krb5ClientState = iota + initiatorRestart + initiatorWaitForMutal + initiatorReady ) func readTypeInfo(r *tdsBuffer) (res typeInfo) { From d3615094cd6127e4f95a442b5077b3eebeb35d69 Mon Sep 17 00:00:00 2001 From: chandanjainn Date: Tue, 7 Dec 2021 16:26:29 +0530 Subject: [PATCH 06/21] removed the global vars in kerbauth.go --- kerbauth.go | 45 +++--------- kerbauth_test.go | 182 ++++++++++++----------------------------------- 2 files changed, 58 insertions(+), 169 deletions(-) diff --git a/kerbauth.go b/kerbauth.go index 71996d73..265772f5 100644 --- a/kerbauth.go +++ b/kerbauth.go @@ -11,9 +11,7 @@ import ( "gopkg.in/jcmturner/gokrb5.v7/config" "gopkg.in/jcmturner/gokrb5.v7/credentials" "gopkg.in/jcmturner/gokrb5.v7/keytab" - "gopkg.in/jcmturner/gokrb5.v7/messages" "gopkg.in/jcmturner/gokrb5.v7/spnego" - "gopkg.in/jcmturner/gokrb5.v7/types" ) type krb5Auth struct { @@ -29,26 +27,9 @@ type krb5Auth struct { state krb5ClientState } -var clientWithKeytab = client.NewClientWithKeytab -var loadCCache = credentials.LoadCCache -var clientFromCCache = client.NewClientFromCCache -var spnegoNewNegToken = spnego.NewNegTokenInitKRB5 -var spnegoToken spnego.SPNEGOToken -var spnegoUnmarshal = spnegoToken.Unmarshal -var kt = &keytab.Keytab{} -var ktUnmarshal = kt.Unmarshal - -var negTokenMarshal = func(negTok spnego.NegTokenInit) ([]byte, error) { - return negTok.Marshal() -} -var getServiceTicket = func(cl *client.Client, spn string) (messages.Ticket, types.EncryptionKey, error) { - return cl.GetServiceTicket(spn) -} - func getKRB5Auth(user, serverSPN, krb5Conf, krbFile, password string, initkrbwithkeytab bool) (auth, bool) { var port uint64 - var realm string - var serviceStr string + var realm, serviceStr string var err error params1 := strings.Split(serverSPN, ":") @@ -107,10 +88,9 @@ func (auth *krb5Auth) InitialBytes() ([]byte, error) { if err != nil { return []byte{}, err } - // Set to lookup KDCs in DNS c.LibDefaults.DNSLookupKDC = false - + var kt = &keytab.Keytab{} var cl *client.Client // Init keytab from conf if auth.initkrbwithkeytab { @@ -118,39 +98,35 @@ func (auth *krb5Auth) InitialBytes() ([]byte, error) { if err != nil { return []byte{}, err } - if err = ktUnmarshal([]byte(keytabConf)); err != nil { + if err = kt.Unmarshal([]byte(keytabConf)); err != nil { return []byte{}, err } - // Init krb5 client and login - cl = clientWithKeytab(auth.username, auth.realm, kt, c, client.DisablePAFXFAST(true)) - + cl = client.NewClientWithKeytab(auth.username, auth.realm, kt, c, client.DisablePAFXFAST(true)) } else { - cache, err := loadCCache(auth.krbFile) + cache, err := credentials.LoadCCache(auth.krbFile) if err != nil { return []byte{}, err } - cl, err = clientFromCCache(cache, c) + cl, err = client.NewClientFromCCache(cache, c) if err != nil { return []byte{}, err } } - auth.krb5Client = cl auth.state = initiatorStart - - tkt, sessionKey, err := getServiceTicket(cl, auth.serverSPN) + tkt, sessionKey, err := cl.GetServiceTicket(auth.serverSPN) if err != nil { return []byte{}, err } - negTok, err := spnegoNewNegToken(auth.krb5Client, tkt, sessionKey) + negTok, err := spnego.NewNegTokenInitKRB5(auth.krb5Client, tkt, sessionKey) if err != nil { return []byte{}, err } - outToken, err := negTokenMarshal(negTok) + outToken, err := negTok.Marshal() if err != nil { return []byte{}, err } @@ -163,7 +139,8 @@ func (auth *krb5Auth) Free() { } func (auth *krb5Auth) NextBytes(token []byte) ([]byte, error) { - if err := spnegoUnmarshal(token); err != nil { + var spnegoToken spnego.SPNEGOToken + if err := spnegoToken.Unmarshal(token); err != nil { err := fmt.Errorf("unmarshal APRep token failed: %w", err) return []byte{}, err } diff --git a/kerbauth_test.go b/kerbauth_test.go index f195435a..0d5e98eb 100644 --- a/kerbauth_test.go +++ b/kerbauth_test.go @@ -1,7 +1,6 @@ package mssql import ( - "errors" "io/ioutil" "os" "path/filepath" @@ -10,17 +9,12 @@ import ( "gopkg.in/jcmturner/gokrb5.v7/client" "gopkg.in/jcmturner/gokrb5.v7/config" - "gopkg.in/jcmturner/gokrb5.v7/credentials" "gopkg.in/jcmturner/gokrb5.v7/keytab" - "gopkg.in/jcmturner/gokrb5.v7/messages" - "gopkg.in/jcmturner/gokrb5.v7/spnego" - "gopkg.in/jcmturner/gokrb5.v7/types" ) func createKrbFile(filename string, t *testing.T) string { - //The byte array is used to create a basic file for testing purpose - ans := []byte{84, 104, 105, 115, 32, 105, 115, 32, 97, 32, 116, 101, 115, 116, 32, 102, 105, 108, 101, 46} - err := ioutil.WriteFile(filename, ans, 0644) + file := []byte("This is a test file") + err := ioutil.WriteFile(filename, file, 0644) if err != nil { t.Errorf("Could not write file") } @@ -32,18 +26,20 @@ func createKrbFile(filename string, t *testing.T) string { } func deleteFile(filename string, t *testing.T) { - if _, err := os.Stat(filename); err == nil { - err = os.Remove(filename) - if err != nil { - t.Errorf("Could not delete file: %v", filename) + defer func() { + if _, err := os.Stat(filename); err == nil { + err = os.Remove(filename) + if err != nil { + t.Errorf("Could not delete file: %v", filename) + } } - } + }() } func TestGetKRB5Auth(t *testing.T) { keytabFile := createKrbFile("admin.keytab", t) got, _ := getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433", "/etc/krb5.conf", keytabFile, "", true) - var keytab auth = &krb5Auth{username: "", + keytab := &krb5Auth{username: "", realm: "domain.com", serverSPN: "MSSQLSvc/mssql.domain.com:1433", password: "", @@ -78,7 +74,7 @@ func TestGetKRB5Auth(t *testing.T) { _, val := getKRB5Auth("", "MSSQLSvc/mssql.domain.com", "/etc/krb5.conf", keytabFile, "", true) if val { - t.Errorf("Failed to get correct krb5Auth object") + t.Errorf("Failed to get correct krb5Auth object: no port defined") } got, _ = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433@DOMAIN.COM", "/etc/krb5.conf", keytabFile, "", true) @@ -101,178 +97,93 @@ func TestGetKRB5Auth(t *testing.T) { if val { t.Errorf("Failed to get correct krb5Auth object due to incorrect serverSPN name") } - defer deleteFile(krbcacheFile, t) - defer deleteFile(keytabFile, t) + + _, val = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:port@domain.com", "", keytabFile, "", true) + if val { + t.Errorf("Failed to get correct krb5Auth object due to incorrect port") + } + + _, val = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:port", "", keytabFile, "", true) + if val { + t.Errorf("Failed to get correct krb5Auth object due to incorrect port") + } + + deleteFile(krbcacheFile, t) + deleteFile(keytabFile, t) } func TestInitialBytes(t *testing.T) { + krb5ConfFile := createKrbFile("krb5.conf", t) krbcacheFile := createKrbFile("krbcache_1000", t) krbObj := &krb5Auth{username: "", realm: "domain.com", serverSPN: "MSSQLSvc/mssql.domain.com:1433", password: "", port: 1433, - krb5ConfFile: "/etc/krb5.conf", + krb5ConfFile: krb5ConfFile, krbFile: krbcacheFile, initkrbwithkeytab: false, state: 0, } - loadCCache = func(cpath string) (*credentials.CCache, error) { - return &credentials.CCache{}, nil - } - - clientFromCCache = func(c *credentials.CCache, krb5conf *config.Config, settings ...func(*client.Settings)) (*client.Client, error) { - return &client.Client{}, nil - } - - getServiceTicket = func(cl *client.Client, spn string) (messages.Ticket, types.EncryptionKey, error) { - return messages.Ticket{}, types.EncryptionKey{}, nil - } - spnegoNewNegToken = func(cl *client.Client, tkt messages.Ticket, sessionKey types.EncryptionKey) (spnego.NegTokenInit, error) { - return spnego.NegTokenInit{}, nil - } - _, err := krbObj.InitialBytes() - if err != nil { - t.Errorf(err.Error()) - } - - loadCCache = func(cpath string) (*credentials.CCache, error) { - return &credentials.CCache{}, errors.New("Error loading cache file") - } - - _, err = krbObj.InitialBytes() - if err == nil { - t.Errorf(err.Error()) - } - - loadCCache = func(cpath string) (*credentials.CCache, error) { - return &credentials.CCache{}, nil - } - clientFromCCache = func(c *credentials.CCache, krb5conf *config.Config, settings ...func(*client.Settings)) (*client.Client, error) { - return &client.Client{}, errors.New("Failed to create a client from CCache") - } - _, err = krbObj.InitialBytes() if err == nil { - t.Errorf(err.Error()) - } - - clientFromCCache = func(c *credentials.CCache, krb5conf *config.Config, settings ...func(*client.Settings)) (*client.Client, error) { - return &client.Client{}, nil - } - getServiceTicket = func(cl *client.Client, spn string) (messages.Ticket, types.EncryptionKey, error) { - return messages.Ticket{}, types.EncryptionKey{}, errors.New("Failed to create service ticket") - } - - _, err = krbObj.InitialBytes() - if err == nil { - t.Errorf(err.Error()) + t.Errorf("Failed to get Initial bytes") } keytabFile := createKrbFile("admin.keytab", t) - krbObj.initkrbwithkeytab = true krbObj.krbFile = keytabFile - - getServiceTicket = func(cl *client.Client, spn string) (messages.Ticket, types.EncryptionKey, error) { - return messages.Ticket{}, types.EncryptionKey{}, nil - } - ktUnmarshal = func(b []byte) error { - return nil - } - - clientWithKeytab = func(username string, realm string, kt *keytab.Keytab, krb5conf *config.Config, settings ...func(*client.Settings)) *client.Client { - return &client.Client{} - } - - _, err = krbObj.InitialBytes() - if err != nil { - t.Errorf(err.Error()) - } - - krbObj.krbFile = "Test" - - _, err = krbObj.InitialBytes() - if err == nil { - t.Errorf(err.Error()) - } - - krbObj.krbFile = keytabFile - ktUnmarshal = func(b []byte) error { - return errors.New("Failed to unmarshal keytab file") - } + krbObj.initkrbwithkeytab = true _, err = krbObj.InitialBytes() if err == nil { - t.Errorf(err.Error()) - } - - ktUnmarshal = func(b []byte) error { - return nil - } - - spnegoNewNegToken = func(cl *client.Client, tkt messages.Ticket, sessionKey types.EncryptionKey) (spnego.NegTokenInit, error) { - return spnego.NegTokenInit{}, errors.New("Failed to create a new spnego token") + t.Errorf("Failed to get Initial bytes") } + krbObj.krb5ConfFile = "test/krb5.conf" _, err = krbObj.InitialBytes() if err == nil { - t.Errorf(err.Error()) - } - - spnegoNewNegToken = func(cl *client.Client, tkt messages.Ticket, sessionKey types.EncryptionKey) (spnego.NegTokenInit, error) { - return spnego.NegTokenInit{}, nil - } - - negTokenMarshal = func(negTok spnego.NegTokenInit) ([]byte, error) { - return []byte{}, errors.New("Failed to marshal neg token") + t.Errorf("Should failed to get Initial bytes as the krb5.conf file path is wrong") } + krbObj.krb5ConfFile = krb5ConfFile + krbObj.krbFile = keytabFile + ".test" _, err = krbObj.InitialBytes() if err == nil { - t.Errorf(err.Error()) + t.Errorf("Should failed to get Initial bytes as the krb5.conf file path is wrong") } - defer deleteFile(krbcacheFile, t) - defer deleteFile(keytabFile, t) + deleteFile(krbcacheFile, t) + deleteFile(keytabFile, t) + deleteFile(krb5ConfFile, t) } func TestNextBytes(t *testing.T) { - ans := []byte{84, 104, 105, 115, 32, 105, 115, 32, 97, 32, 116, 101, 115, 116, 32, 102, 105, 108, 101, 46} - spnegoUnmarshal = func(b []byte) error { - return nil - } - + ans := []byte{} keytabFile := createKrbFile("admin.keytab", t) + krb5ConfFile := createKrbFile("krb5.conf", t) var krbObj auth = &krb5Auth{username: "", realm: "domain.com", serverSPN: "MSSQLSvc/mssql.domain.com:1433", password: "", port: 1433, - krb5ConfFile: "/etc/krb5.conf", + krb5ConfFile: krb5ConfFile, krbFile: keytabFile, initkrbwithkeytab: true, state: 0} _, err := krbObj.NextBytes(ans) - if err != nil { - t.Errorf("Error getting next byte") - } - - spnegoUnmarshal = func(b []byte) error { - return errors.New("Failed to unmarshal") - } - - _, err = krbObj.NextBytes(ans) if err == nil { - t.Errorf("Should fail to unmarshal but passed") + t.Errorf("Error getting next byte") } - defer deleteFile(keytabFile, t) + deleteFile(keytabFile, t) + deleteFile(krb5ConfFile, t) } func TestFree(t *testing.T) { keytabFile := createKrbFile("admin.keytab", t) + krb5ConfFile := createKrbFile("krb5.conf", t) kt := &keytab.Keytab{} c := &config.Config{} cl := client.NewClientWithKeytab("Administrator", "DOMAIN.COM", kt, c, client.DisablePAFXFAST(true)) @@ -281,7 +192,7 @@ func TestFree(t *testing.T) { serverSPN: "MSSQLSvc/mssql.domain.com:1433", password: "", port: 1433, - krb5ConfFile: "/etc/krb5.conf", + krb5ConfFile: krb5ConfFile, krbFile: keytabFile, initkrbwithkeytab: true, state: 0, @@ -289,5 +200,6 @@ func TestFree(t *testing.T) { } krbObj.Free() - defer deleteFile(keytabFile, t) + deleteFile(keytabFile, t) + deleteFile(krb5ConfFile, t) } From 86e0074af7020d360e335c045513f4b00de58d61 Mon Sep 17 00:00:00 2001 From: chandanjainn Date: Wed, 8 Dec 2021 17:25:55 +0530 Subject: [PATCH 07/21] migrated gokrb from v7 to v8. --- go.mod | 11 +--- go.sum | 48 +++++++++++---- kerbauth.go | 52 +++++----------- kerbauth_test.go | 135 ++++++++++++++--------------------------- msdsn/conn_str.go | 70 +++++++++++++++++---- msdsn/conn_str_test.go | 43 +++++++++++-- tds.go | 5 +- 7 files changed, 199 insertions(+), 165 deletions(-) diff --git a/go.mod b/go.mod index cc3029b8..34cd8cdc 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,7 @@ go 1.11 require ( github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe - github.com/hashicorp/go-uuid v1.0.2 // indirect - github.com/jcmturner/gofork v1.0.0 // indirect - golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c - gopkg.in/jcmturner/aescts.v1 v1.0.1 // indirect - gopkg.in/jcmturner/dnsutils.v1 v1.0.1 // indirect - gopkg.in/jcmturner/goidentity.v3 v3.0.0 // indirect - gopkg.in/jcmturner/gokrb5.v7 v7.5.0 - gopkg.in/jcmturner/rpc.v1 v1.1.0 // indirect + github.com/jcmturner/gokrb5/v8 v8.4.2 + github.com/stretchr/testify v1.7.0 // indirect + golang.org/x/crypto v0.0.0-20201112155050-0c6587e931a9 ) diff --git a/go.sum b/go.sum index 433b68f8..b89ebbed 100644 --- a/go.sum +++ b/go.sum @@ -1,19 +1,43 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= +github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= +github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= +github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE= github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= +github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= +github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= +github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= github.com/jcmturner/gofork v1.0.0 h1:J7uCkflzTEhUZ64xqKnkDxq3kzc96ajM1Gli5ktUem8= github.com/jcmturner/gofork v1.0.0/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o= -golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= -golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o= +github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= +github.com/jcmturner/gokrb5/v8 v8.4.2 h1:6ZIM6b/JJN0X8UM43ZOM6Z4SJzla+a/u7scXFJzodkA= +github.com/jcmturner/gokrb5/v8 v8.4.2/go.mod h1:sb+Xq/fTY5yktf/VxLsE3wlfPqQjp0aWNYyvBVK62bc= +github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= +github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20201112155050-0c6587e931a9 h1:umElSU9WZirRdgu2yFHY0ayQkEnKiOC1TtM3fWXFnoU= +golang.org/x/crypto v0.0.0-20201112155050-0c6587e931a9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa h1:F+8P+gmewFQYRk6JoLQLwjBCTu3mcIURZfNkVweuRKA= +golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -gopkg.in/jcmturner/aescts.v1 v1.0.1 h1:cVVZBK2b1zY26haWB4vbBiZrfFQnfbTVrE3xZq6hrEw= -gopkg.in/jcmturner/aescts.v1 v1.0.1/go.mod h1:nsR8qBOg+OucoIW+WMhB3GspUQXq9XorLnQb9XtvcOo= -gopkg.in/jcmturner/dnsutils.v1 v1.0.1 h1:cIuC1OLRGZrld+16ZJvvZxVJeKPsvd5eUIvxfoN5hSM= -gopkg.in/jcmturner/dnsutils.v1 v1.0.1/go.mod h1:m3v+5svpVOhtFAP/wSz+yzh4Mc0Fg7eRhxkJMWSIz9Q= -gopkg.in/jcmturner/goidentity.v3 v3.0.0 h1:1duIyWiTaYvVx3YX2CYtpJbUFd7/UuPYCfgXtQ3VTbI= -gopkg.in/jcmturner/goidentity.v3 v3.0.0/go.mod h1:oG2kH0IvSYNIu80dVAyu/yoefjq1mNfM5bm88whjWx4= -gopkg.in/jcmturner/gokrb5.v7 v7.5.0 h1:a9tsXlIDD9SKxotJMK3niV7rPZAJeX2aD/0yg3qlIrg= -gopkg.in/jcmturner/gokrb5.v7 v7.5.0/go.mod h1:l8VISx+WGYp+Fp7KRbsiUuXTTOnxIc3Tuvyavf11/WM= -gopkg.in/jcmturner/rpc.v1 v1.1.0 h1:QHIUxTX1ISuAv9dD2wJ9HWQVuWDX/Zc0PfeC2tjc4rU= -gopkg.in/jcmturner/rpc.v1 v1.1.0/go.mod h1:YIdkC4XfD6GXbzje11McwsDuOlZQSb9W4vfLvuNnlv8= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/kerbauth.go b/kerbauth.go index 265772f5..74804da4 100644 --- a/kerbauth.go +++ b/kerbauth.go @@ -2,16 +2,14 @@ package mssql import ( "fmt" - "io/ioutil" - "os" "strconv" "strings" - "gopkg.in/jcmturner/gokrb5.v7/client" - "gopkg.in/jcmturner/gokrb5.v7/config" - "gopkg.in/jcmturner/gokrb5.v7/credentials" - "gopkg.in/jcmturner/gokrb5.v7/keytab" - "gopkg.in/jcmturner/gokrb5.v7/spnego" + "github.com/jcmturner/gokrb5/v8/client" + "github.com/jcmturner/gokrb5/v8/config" + "github.com/jcmturner/gokrb5/v8/credentials" + "github.com/jcmturner/gokrb5/v8/keytab" + "github.com/jcmturner/gokrb5/v8/spnego" ) type krb5Auth struct { @@ -20,14 +18,15 @@ type krb5Auth struct { serverSPN string password string port uint64 - krb5ConfFile string - krbFile string + krb5Config *config.Config + krbKeytab *keytab.Keytab + krbCache *credentials.CCache initkrbwithkeytab bool krb5Client *client.Client state krb5ClientState } -func getKRB5Auth(user, serverSPN, krb5Conf, krbFile, password string, initkrbwithkeytab bool) (auth, bool) { +func getKRB5Auth(user, password, serverSPN string, krb5Conf *config.Config, keytabContent *keytab.Keytab, cacheContent *credentials.CCache, initkrbwithkeytab bool) (auth, bool) { var port uint64 var realm, serviceStr string var err error @@ -72,44 +71,25 @@ func getKRB5Auth(user, serverSPN, krb5Conf, krbFile, password string, initkrbwit serverSPN: serviceStr, port: port, realm: realm, - krb5ConfFile: krb5Conf, - krbFile: krbFile, + krb5Config: krb5Conf, + krbKeytab: keytabContent, + krbCache: cacheContent, password: password, initkrbwithkeytab: initkrbwithkeytab, }, true } func (auth *krb5Auth) InitialBytes() ([]byte, error) { - krb5CnfFile, err := os.Open(auth.krb5ConfFile) - if err != nil { - return []byte{}, err - } - c, err := config.NewConfigFromReader(krb5CnfFile) - if err != nil { - return []byte{}, err - } // Set to lookup KDCs in DNS - c.LibDefaults.DNSLookupKDC = false - var kt = &keytab.Keytab{} + auth.krb5Config.LibDefaults.DNSLookupKDC = false var cl *client.Client + var err error // Init keytab from conf if auth.initkrbwithkeytab { - keytabConf, err := ioutil.ReadFile(auth.krbFile) - if err != nil { - return []byte{}, err - } - if err = kt.Unmarshal([]byte(keytabConf)); err != nil { - return []byte{}, err - } // Init krb5 client and login - cl = client.NewClientWithKeytab(auth.username, auth.realm, kt, c, client.DisablePAFXFAST(true)) + cl = client.NewWithKeytab(auth.username, auth.realm, auth.krbKeytab, auth.krb5Config, client.DisablePAFXFAST(true)) } else { - cache, err := credentials.LoadCCache(auth.krbFile) - if err != nil { - return []byte{}, err - } - - cl, err = client.NewClientFromCCache(cache, c) + cl, err = client.NewFromCCache(auth.krbCache, auth.krb5Config) if err != nil { return []byte{}, err } diff --git a/kerbauth_test.go b/kerbauth_test.go index 0d5e98eb..6c33ded3 100644 --- a/kerbauth_test.go +++ b/kerbauth_test.go @@ -1,70 +1,47 @@ package mssql import ( - "io/ioutil" - "os" - "path/filepath" "reflect" "testing" - "gopkg.in/jcmturner/gokrb5.v7/client" - "gopkg.in/jcmturner/gokrb5.v7/config" - "gopkg.in/jcmturner/gokrb5.v7/keytab" + "github.com/jcmturner/gokrb5/v8/client" + "github.com/jcmturner/gokrb5/v8/config" + "github.com/jcmturner/gokrb5/v8/credentials" + "github.com/jcmturner/gokrb5/v8/keytab" ) -func createKrbFile(filename string, t *testing.T) string { - file := []byte("This is a test file") - err := ioutil.WriteFile(filename, file, 0644) - if err != nil { - t.Errorf("Could not write file") - } - filedirectory := filepath.Dir(filename) - thepath, _ := filepath.Abs(filedirectory) - filePath := thepath + "/" + filename - - return filePath -} - -func deleteFile(filename string, t *testing.T) { - defer func() { - if _, err := os.Stat(filename); err == nil { - err = os.Remove(filename) - if err != nil { - t.Errorf("Could not delete file: %v", filename) - } - } - }() -} - func TestGetKRB5Auth(t *testing.T) { - keytabFile := createKrbFile("admin.keytab", t) - got, _ := getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433", "/etc/krb5.conf", keytabFile, "", true) + krbConf := &config.Config{} + krbKeytab := &keytab.Keytab{} + krbCache := &credentials.CCache{} + + got, _ := getKRB5Auth("", "", "MSSQLSvc/mssql.domain.com:1433", krbConf, krbKeytab, krbCache, true) keytab := &krb5Auth{username: "", realm: "domain.com", serverSPN: "MSSQLSvc/mssql.domain.com:1433", password: "", port: 1433, - krb5ConfFile: "/etc/krb5.conf", - krbFile: keytabFile, + krb5Config: krbConf, + krbKeytab: krbKeytab, + krbCache: krbCache, initkrbwithkeytab: true, state: 0} res := reflect.DeepEqual(got, keytab) - if !res { t.Errorf("Failed to get correct krb5Auth object\nExpected:%v\nRecieved:%v", keytab, got) } - krbcacheFile := createKrbFile("krb5ccache_1000", t) - got, _ = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433", "/etc/krb5.conf", krbcacheFile, "", true) + got, _ = getKRB5Auth("", "", "MSSQLSvc/mssql.domain.com:1433", krbConf, krbKeytab, krbCache, false) keytab = &krb5Auth{username: "", realm: "domain.com", serverSPN: "MSSQLSvc/mssql.domain.com:1433", password: "", port: 1433, - krb5ConfFile: "/etc/krb5.conf", - krbFile: krbcacheFile, - initkrbwithkeytab: true, + krb5Config: krbConf, + krbKeytab: krbKeytab, + krbCache: krbCache, + initkrbwithkeytab: false, state: 0} res = reflect.DeepEqual(got, keytab) @@ -72,19 +49,20 @@ func TestGetKRB5Auth(t *testing.T) { t.Errorf("Failed to get correct krb5Auth object\nExpected:%v\nRecieved:%v", keytab, got) } - _, val := getKRB5Auth("", "MSSQLSvc/mssql.domain.com", "/etc/krb5.conf", keytabFile, "", true) + _, val := getKRB5Auth("", "", "MSSQLSvc/mssql.domain.com", krbConf, krbKeytab, krbCache, true) if val { t.Errorf("Failed to get correct krb5Auth object: no port defined") } - got, _ = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433@DOMAIN.COM", "/etc/krb5.conf", keytabFile, "", true) + got, _ = getKRB5Auth("", "", "MSSQLSvc/mssql.domain.com:1433@DOMAIN.COM", krbConf, krbKeytab, krbCache, true) keytab = &krb5Auth{username: "", realm: "DOMAIN.COM", serverSPN: "MSSQLSvc/mssql.domain.com:1433", password: "", port: 1433, - krb5ConfFile: "/etc/krb5.conf", - krbFile: keytabFile, + krb5Config: krbConf, + krbKeytab: krbKeytab, + krbCache: krbCache, initkrbwithkeytab: true, state: 0} @@ -93,35 +71,34 @@ func TestGetKRB5Auth(t *testing.T) { t.Errorf("Failed to get correct krb5Auth object\nExpected:%v\nRecieved:%v", keytab, got) } - _, val = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433@domain.com@test", "", keytabFile, "", true) + _, val = getKRB5Auth("", "", "MSSQLSvc/mssql.domain.com:1433@domain.com@test", krbConf, krbKeytab, krbCache, true) if val { t.Errorf("Failed to get correct krb5Auth object due to incorrect serverSPN name") } - _, val = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:port@domain.com", "", keytabFile, "", true) + _, val = getKRB5Auth("", "", "MSSQLSvc/mssql.domain.com:port@domain.com", krbConf, krbKeytab, krbCache, true) if val { t.Errorf("Failed to get correct krb5Auth object due to incorrect port") } - _, val = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:port", "", keytabFile, "", true) + _, val = getKRB5Auth("", "", "MSSQLSvc/mssql.domain.com:port", krbConf, krbKeytab, krbCache, true) if val { t.Errorf("Failed to get correct krb5Auth object due to incorrect port") } - - deleteFile(krbcacheFile, t) - deleteFile(keytabFile, t) } func TestInitialBytes(t *testing.T) { - krb5ConfFile := createKrbFile("krb5.conf", t) - krbcacheFile := createKrbFile("krbcache_1000", t) + krbConf := &config.Config{} + krbKeytab := &keytab.Keytab{} + krbCache := &credentials.CCache{} krbObj := &krb5Auth{username: "", realm: "domain.com", serverSPN: "MSSQLSvc/mssql.domain.com:1433", password: "", port: 1433, - krb5ConfFile: krb5ConfFile, - krbFile: krbcacheFile, + krb5Config: krbConf, + krbKeytab: krbKeytab, + krbCache: krbCache, initkrbwithkeytab: false, state: 0, } @@ -131,44 +108,27 @@ func TestInitialBytes(t *testing.T) { t.Errorf("Failed to get Initial bytes") } - keytabFile := createKrbFile("admin.keytab", t) - krbObj.krbFile = keytabFile krbObj.initkrbwithkeytab = true - _, err = krbObj.InitialBytes() if err == nil { t.Errorf("Failed to get Initial bytes") } - - krbObj.krb5ConfFile = "test/krb5.conf" - _, err = krbObj.InitialBytes() - if err == nil { - t.Errorf("Should failed to get Initial bytes as the krb5.conf file path is wrong") - } - - krbObj.krb5ConfFile = krb5ConfFile - krbObj.krbFile = keytabFile + ".test" - _, err = krbObj.InitialBytes() - if err == nil { - t.Errorf("Should failed to get Initial bytes as the krb5.conf file path is wrong") - } - - deleteFile(krbcacheFile, t) - deleteFile(keytabFile, t) - deleteFile(krb5ConfFile, t) } func TestNextBytes(t *testing.T) { ans := []byte{} - keytabFile := createKrbFile("admin.keytab", t) - krb5ConfFile := createKrbFile("krb5.conf", t) + krbConf := &config.Config{} + krbKeytab := &keytab.Keytab{} + krbCache := &credentials.CCache{} + var krbObj auth = &krb5Auth{username: "", realm: "domain.com", serverSPN: "MSSQLSvc/mssql.domain.com:1433", password: "", port: 1433, - krb5ConfFile: krb5ConfFile, - krbFile: keytabFile, + krb5Config: krbConf, + krbKeytab: krbKeytab, + krbCache: krbCache, initkrbwithkeytab: true, state: 0} @@ -176,30 +136,27 @@ func TestNextBytes(t *testing.T) { if err == nil { t.Errorf("Error getting next byte") } - - deleteFile(keytabFile, t) - deleteFile(krb5ConfFile, t) } func TestFree(t *testing.T) { - keytabFile := createKrbFile("admin.keytab", t) - krb5ConfFile := createKrbFile("krb5.conf", t) + krbConf := &config.Config{} + krbKeytab := &keytab.Keytab{} + krbCache := &credentials.CCache{} kt := &keytab.Keytab{} c := &config.Config{} - cl := client.NewClientWithKeytab("Administrator", "DOMAIN.COM", kt, c, client.DisablePAFXFAST(true)) + cl := client.NewWithKeytab("Administrator", "DOMAIN.COM", kt, c, client.DisablePAFXFAST(true)) + var krbObj auth = &krb5Auth{username: "", realm: "domain.com", serverSPN: "MSSQLSvc/mssql.domain.com:1433", password: "", port: 1433, - krb5ConfFile: krb5ConfFile, - krbFile: keytabFile, + krb5Config: krbConf, + krbKeytab: krbKeytab, + krbCache: krbCache, initkrbwithkeytab: true, state: 0, krb5Client: cl, } - krbObj.Free() - deleteFile(keytabFile, t) - deleteFile(krb5ConfFile, t) } diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index bc963245..0b4eb708 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -12,6 +12,10 @@ import ( "strings" "time" "unicode" + + "github.com/jcmturner/gokrb5/v8/config" + "github.com/jcmturner/gokrb5/v8/credentials" + "github.com/jcmturner/gokrb5/v8/keytab" ) const defaultServerPort = 1433 @@ -75,12 +79,11 @@ type Config struct { KeepAlive time.Duration // Leave at default. PacketSize uint16 - // Kerberos authentication fields - // Path to krb5.conf file that contains Kerberos configuration information - Krb5ConfFilePath string + // Kerberos configuration details + Krb5Conf *config.Config - // Credential cache path - KrbCachePath string + // Credential cache + KrbCache *credentials.CCache // A Kerberos realm is the domain over which a Kerberos authentication server has the authority // to authenticate a user, host or service. @@ -89,8 +92,8 @@ type Config struct { // Flag to authenticate using keytab file Initkrbwithkeytab bool - // Path to keytab file that stores long-term keys for one or more principals - KeytabFilePath string + // Kerberos keytab that stores long-term keys for one or more principals + KrbKeytab *keytab.Keytab // Flag to enable kerberos authentication EnableKerberos bool @@ -217,12 +220,21 @@ func Parse(dsn string) (Config, map[string]string, error) { krbCache, ok := params["krbcache"] if ok { - p.KrbCachePath = krbCache + var err error + p.KrbCache, err = setupKerbCache(krbCache) + if err != nil { + return p, params, fmt.Errorf("cannot read kerberos cache file: %v", err) + } } krb5ConfFile, ok := params["krb5conffile"] if ok { - p.Krb5ConfFilePath = krb5ConfFile + var err error + p.Krb5Conf, err = setupKerbConfig(krb5ConfFile) + if err != nil { + return p, params, fmt.Errorf("cannot read kerberos configuration file: %v", err) + } + } initkrbwithkeytab, ok := params["initkrbwithkeytab"] @@ -237,9 +249,12 @@ func Parse(dsn string) (Config, map[string]string, error) { keytabfile, ok := params["keytabfile"] if ok { - p.KeytabFilePath = keytabfile + var err error + p.KrbKeytab, err = setupKerbKeytab(keytabfile) + if err != nil { + return p, params, fmt.Errorf("cannot read kerberos keytab file: %v", err) + } } - } // https://msdn.microsoft.com/en-us/library/dd341108.aspx @@ -688,7 +703,6 @@ func resolveServerPort(port uint64) uint64 { if port == 0 { return defaultServerPort } - return port } @@ -698,3 +712,35 @@ func generateSpn(host string, port uint64, realm string) string { } return fmt.Sprintf("MSSQLSvc/%s:%d@%s", host, port, realm) } + +func setupKerbConfig(krb5configPath string) (*config.Config, error) { + krb5CnfFile, err := os.Open(krb5configPath) + if err != nil { + return nil, err + } + c, err := config.NewFromReader(krb5CnfFile) + if err != nil { + return nil, err + } + return c, nil +} + +func setupKerbCache(kerbCCahePath string) (*credentials.CCache, error) { + cache, err := credentials.LoadCCache(kerbCCahePath) + if err != nil { + return nil, err + } + return cache, nil +} + +func setupKerbKeytab(keytabFilePath string) (*keytab.Keytab, error) { + var kt = &keytab.Keytab{} + keytabConf, err := ioutil.ReadFile(keytabFilePath) + if err != nil { + return nil, err + } + if err = kt.Unmarshal([]byte(keytabConf)); err != nil { + return nil, err + } + return kt, nil +} diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index 53ef6d29..3de60363 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -1,6 +1,9 @@ package msdsn import ( + "io/ioutil" + "os" + "path/filepath" "reflect" "testing" "time" @@ -53,6 +56,7 @@ func TestValidConnectionString(t *testing.T) { {"server=server\\instance;database=testdb;user id=tester;password=pwd", func(p Config) bool { return p.Host == "server" && p.Instance == "instance" && p.User == "tester" && p.Password == "pwd" }}, + {"server=.", func(p Config) bool { return p.Host == "localhost" }}, {"server=(local)", func(p Config) bool { return p.Host == "localhost" }}, {"ServerSPN=serverspn;Workstation ID=workstid", func(p Config) bool { return p.ServerSPN == "serverspn" && p.Workstation == "workstid" }}, @@ -198,7 +202,6 @@ func TestConnParseRoundTripFixed(t *testing.T) { } func TestInvalidConnectionStringKerberos(t *testing.T) { - connStrings := []string{ "server=server;port=1345;realm=domain;trustservercertificate=true;keytabfile=/path/to/administrator2.keytab;enablekerberos=true", "server=server;port=1345;realm=domain;trustservercertificate=true;krbcache=;enablekerberos=true", @@ -219,18 +222,46 @@ func TestInvalidConnectionStringKerberos(t *testing.T) { } func TestValidConnectionStringKerberos(t *testing.T) { + krbcache := createKrbFile("krbcache_1000", t) + keytab := createKrbFile("admin.keytab", t) + krbconf := createKrbFile("krb5.conf", t) connStrings := []string{ - "server=server;user id=user;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;keytabfile=/path/to/admin.keytab;enablekerberos=true;initkrbwithkeytab=true", - "server=server;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;krbcache=/tmp/krb5cc_1000;enablekerberos=true", + "server=server;user id=user;port=1345;realm=domain;trustservercertificate=true;krb5conffile=" + krbconf + ";keytabfile=" + keytab + ";enablekerberos=true;initkrbwithkeytab=true", + "server=server;port=1345;realm=domain;trustservercertificate=true;krb5conffile=" + krbconf + ";krbcache=" + krbcache + ";enablekerberos=true", } for _, connStr := range connStrings { _, _, err := Parse(connStr) if err == nil { - t.Logf("Connection string was parsed successfully %s", connStrings) - } else { - t.Errorf("Connection string %s failed to parse with error %s", connStrings, err) + t.Errorf("Connection string %s should fail to parse with error %s", connStrings, err) } } + deleteFile(krbcache, t) + deleteFile(krbconf, t) + deleteFile(keytab, t) +} + +func createKrbFile(filename string, t *testing.T) string { + file := []byte("This is a test file") + err := ioutil.WriteFile(filename, file, 0644) + if err != nil { + t.Errorf("Could not write file") + } + filedirectory := filepath.Dir(filename) + thepath, _ := filepath.Abs(filedirectory) + filePath := thepath + "/" + filename + + return filePath +} + +func deleteFile(filename string, t *testing.T) { + defer func() { + if _, err := os.Stat(filename); err == nil { + err = os.Remove(filename) + if err != nil { + t.Errorf("Could not delete file: %v", filename) + } + } + }() } diff --git a/tds.go b/tds.go index a453b852..e27bf020 100644 --- a/tds.go +++ b/tds.go @@ -1011,6 +1011,7 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont if err != nil { return nil, err } + l.OptionFlags2 |= fIntSecurity return l, nil @@ -1160,9 +1161,9 @@ initiate_connection: var authOk bool if p.EnableKerberos { if p.Initkrbwithkeytab { - auth, authOk = getKRB5Auth(p.User, p.ServerSPN, p.Krb5ConfFilePath, p.KeytabFilePath, p.Password, p.Initkrbwithkeytab) + auth, authOk = getKRB5Auth(p.User, p.Password, p.ServerSPN, p.Krb5Conf, p.KrbKeytab, p.KrbCache, p.Initkrbwithkeytab) } else { - auth, authOk = getKRB5Auth(p.User, p.ServerSPN, p.Krb5ConfFilePath, p.KrbCachePath, p.Password, p.Initkrbwithkeytab) + auth, authOk = getKRB5Auth(p.User, p.Password, p.ServerSPN, p.Krb5Conf, p.KrbKeytab, p.KrbCache, p.Initkrbwithkeytab) } } else { auth, authOk = getAuth(p.User, p.Password, p.ServerSPN, p.Workstation) From eca4758fe85338942f2b502ea0d0fce2bbc8fda2 Mon Sep 17 00:00:00 2001 From: chandanjainn Date: Thu, 9 Dec 2021 13:35:12 +0530 Subject: [PATCH 08/21] worked on error message --- kerbauth.go | 2 +- msdsn/conn_str.go | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/kerbauth.go b/kerbauth.go index 74804da4..8699e4eb 100644 --- a/kerbauth.go +++ b/kerbauth.go @@ -121,7 +121,7 @@ func (auth *krb5Auth) Free() { func (auth *krb5Auth) NextBytes(token []byte) ([]byte, error) { var spnegoToken spnego.SPNEGOToken if err := spnegoToken.Unmarshal(token); err != nil { - err := fmt.Errorf("unmarshal APRep token failed: %w", err) + err := fmt.Errorf("unmarshal APRep token failed: %v", err) return []byte{}, err } auth.state = initiatorReady diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index 0b4eb708..2ee92ac6 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -203,14 +203,13 @@ func Parse(dsn string) (Config, map[string]string, error) { var err error p.EnableKerberos, err = strconv.ParseBool(enablekerberos) if err != nil { - f := "invalid enablekerberos flag '%v': %v" - return p, params, fmt.Errorf(f, enablekerberos, err.Error()) + return p, params, fmt.Errorf("invalid enablekerberos flag '%v': %v", enablekerberos, err.Error()) } } if p.EnableKerberos { missingParam := checkMissingKRBConfig(params) if missingParam != "" { - return p, params, fmt.Errorf(" %s cannot be empty", missingParam) + return p, params, fmt.Errorf("missing parameter:%s", missingParam) } realm, ok := params["realm"] @@ -223,7 +222,7 @@ func Parse(dsn string) (Config, map[string]string, error) { var err error p.KrbCache, err = setupKerbCache(krbCache) if err != nil { - return p, params, fmt.Errorf("cannot read kerberos cache file: %v", err) + return p, params, fmt.Errorf("cannot read kerberos cache file: %v", err.Error()) } } @@ -232,7 +231,7 @@ func Parse(dsn string) (Config, map[string]string, error) { var err error p.Krb5Conf, err = setupKerbConfig(krb5ConfFile) if err != nil { - return p, params, fmt.Errorf("cannot read kerberos configuration file: %v", err) + return p, params, fmt.Errorf("cannot read kerberos configuration file: %v", err.Error()) } } @@ -242,8 +241,7 @@ func Parse(dsn string) (Config, map[string]string, error) { var err error p.Initkrbwithkeytab, err = strconv.ParseBool(initkrbwithkeytab) if err != nil { - f := "invalid initkrbwithkeytab flag '%v': %v" - return p, params, fmt.Errorf(f, initkrbwithkeytab, err.Error()) + return p, params, fmt.Errorf("invalid initkrbwithkeytab flag '%v': %v", initkrbwithkeytab, err.Error()) } } @@ -252,7 +250,7 @@ func Parse(dsn string) (Config, map[string]string, error) { var err error p.KrbKeytab, err = setupKerbKeytab(keytabfile) if err != nil { - return p, params, fmt.Errorf("cannot read kerberos keytab file: %v", err) + return p, params, fmt.Errorf("cannot read kerberos keytab file: %v", err.Error()) } } } @@ -410,9 +408,11 @@ func checkMissingKRBConfig(c map[string]string) (missingParam string) { if c["initkrbwithkeytab"] == "true" { if c["keytabfile"] == "" { missingParam = "keytabfile" + return } if c["realm"] == "" { missingParam = "realm" + return } } else if c["krbcache"] == "" { missingParam = "krbcache" From c47a35b5b300cc88ab3671b7856365f88d958f64 Mon Sep 17 00:00:00 2001 From: chandanjainn Date: Mon, 13 Dec 2021 10:11:37 +0530 Subject: [PATCH 09/21] worked on error messages. --- kerbauth.go | 2 +- msdsn/conn_str.go | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/kerbauth.go b/kerbauth.go index 8699e4eb..74804da4 100644 --- a/kerbauth.go +++ b/kerbauth.go @@ -121,7 +121,7 @@ func (auth *krb5Auth) Free() { func (auth *krb5Auth) NextBytes(token []byte) ([]byte, error) { var spnegoToken spnego.SPNEGOToken if err := spnegoToken.Unmarshal(token); err != nil { - err := fmt.Errorf("unmarshal APRep token failed: %v", err) + err := fmt.Errorf("unmarshal APRep token failed: %w", err) return []byte{}, err } auth.state = initiatorReady diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index 2ee92ac6..8a83bfbf 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -203,7 +203,7 @@ func Parse(dsn string) (Config, map[string]string, error) { var err error p.EnableKerberos, err = strconv.ParseBool(enablekerberos) if err != nil { - return p, params, fmt.Errorf("invalid enablekerberos flag '%v': %v", enablekerberos, err.Error()) + return p, params, fmt.Errorf("invalid enablekerberos flag '%v': %w", enablekerberos, err) } } if p.EnableKerberos { @@ -222,7 +222,7 @@ func Parse(dsn string) (Config, map[string]string, error) { var err error p.KrbCache, err = setupKerbCache(krbCache) if err != nil { - return p, params, fmt.Errorf("cannot read kerberos cache file: %v", err.Error()) + return p, params, fmt.Errorf("cannot read kerberos cache file: %w", err) } } @@ -231,7 +231,7 @@ func Parse(dsn string) (Config, map[string]string, error) { var err error p.Krb5Conf, err = setupKerbConfig(krb5ConfFile) if err != nil { - return p, params, fmt.Errorf("cannot read kerberos configuration file: %v", err.Error()) + return p, params, fmt.Errorf("cannot read kerberos configuration file: %w", err) } } @@ -241,7 +241,7 @@ func Parse(dsn string) (Config, map[string]string, error) { var err error p.Initkrbwithkeytab, err = strconv.ParseBool(initkrbwithkeytab) if err != nil { - return p, params, fmt.Errorf("invalid initkrbwithkeytab flag '%v': %v", initkrbwithkeytab, err.Error()) + return p, params, fmt.Errorf("invalid initkrbwithkeytab flag '%v': %w", initkrbwithkeytab, err) } } @@ -250,7 +250,7 @@ func Parse(dsn string) (Config, map[string]string, error) { var err error p.KrbKeytab, err = setupKerbKeytab(keytabfile) if err != nil { - return p, params, fmt.Errorf("cannot read kerberos keytab file: %v", err.Error()) + return p, params, fmt.Errorf("cannot read kerberos keytab file: %w", err) } } } From 27f50f15422458f0b4345a1c73ab493f4a075b87 Mon Sep 17 00:00:00 2001 From: chandanjainn Date: Mon, 20 Dec 2021 16:58:44 +0530 Subject: [PATCH 10/21] updated readme. --- README.md | 251 +++++++++++++++++++++++++++++------------------------- 1 file changed, 134 insertions(+), 117 deletions(-) diff --git a/README.md b/README.md index 92e74b65..fb18c192 100644 --- a/README.md +++ b/README.md @@ -18,99 +18,116 @@ Other supported formats are listed below. ### Common parameters -* `user id` - enter the SQL Server Authentication user id or the Windows Authentication user id in the DOMAIN\User format. On Windows, if user id is empty or missing Single-Sign-On is used. The user domain sensitive to the case which is defined in the connection string. -* `password` -* `database` -* `connection timeout` - in seconds (default is 0 for no timeout), set to 0 for no timeout. Recommended to set to 0 and use context to manage query and connection timeouts. -* `dial timeout` - in seconds (default is 15), set to 0 for no timeout -* `encrypt` - * `disable` - Data send between client and server is not encrypted. - * `false` - Data sent between client and server is not encrypted beyond the login packet. (Default) - * `true` - Data sent between client and server is encrypted. -* `app name` - The application name (default is go-mssqldb) +- `user id` - enter the SQL Server Authentication user id or the Windows Authentication user id in the DOMAIN\User format. On Windows, if user id is empty or missing Single-Sign-On is used. The user domain sensitive to the case which is defined in the connection string. +- `password` +- `database` +- `connection timeout` - in seconds (default is 0 for no timeout), set to 0 for no timeout. Recommended to set to 0 and use context to manage query and connection timeouts. +- `dial timeout` - in seconds (default is 15), set to 0 for no timeout +- `encrypt` + - `disable` - Data send between client and server is not encrypted. + - `false` - Data sent between client and server is not encrypted beyond the login packet. (Default) + - `true` - Data sent between client and server is encrypted. +- `app name` - The application name (default is go-mssqldb) + +### Kerberos Parameters + +- `enablekerberos`-It is a boolean flag to enable kerberos authentication mechanism. +- `krb5conffile`-File path for kerberos configuration file. +- `realm`-Domain name for kerberos authentication. +- `initkrbwithkeytab`-It is a boolean flag to enable kerberos authentication using keytab file. +- `keytabfile`-Keytab file path. +- `krbcache`-Credential cache path. ### Connection parameters for ODBC and ADO style connection strings -* `server` - host or host\instance (default localhost) -* `port` - used only when there is no instance in server (default 1433) +- `server` - host or host\instance (default localhost) +- `port` - used only when there is no instance in server (default 1433) ### Less common parameters -* `keepAlive` - in seconds; 0 to disable (default is 30) -* `failoverpartner` - host or host\instance (default is no partner). -* `failoverport` - used only when there is no instance in failoverpartner (default 1433) -* `packet size` - in bytes; 512 to 32767 (default is 4096) - * Encrypted connections have a maximum packet size of 16383 bytes - * Further information on usage: -* `log` - logging flags (default 0/no logging, 63 for full logging) - * 1 log errors - * 2 log messages - * 4 log rows affected - * 8 trace sql statements - * 16 log statement parameters - * 32 log transaction begin/end -* `TrustServerCertificate` - * false - Server certificate is checked. Default is false if encypt is specified. - * true - Server certificate is not checked. Default is true if encrypt is not specified. If trust server certificate is true, driver accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing. -* `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates. -* `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host. -* `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port. -* `Workstation ID` - The workstation name (default is the host name) -* `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`. +- `keepAlive` - in seconds; 0 to disable (default is 30) +- `failoverpartner` - host or host\instance (default is no partner). +- `failoverport` - used only when there is no instance in failoverpartner (default 1433) +- `packet size` - in bytes; 512 to 32767 (default is 4096) + - Encrypted connections have a maximum packet size of 16383 bytes + - Further information on usage: +- `log` - logging flags (default 0/no logging, 63 for full logging) + - 1 log errors + - 2 log messages + - 4 log rows affected + - 8 trace sql statements + - 16 log statement parameters + - 32 log transaction begin/end +- `TrustServerCertificate` + - false - Server certificate is checked. Default is false if encypt is specified. + - true - Server certificate is not checked. Default is true if encrypt is not specified. If trust server certificate is true, driver accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing. +- `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates. +- `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host. +- `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port. +- `Workstation ID` - The workstation name (default is the host name) +- `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`. ### The connection string can be specified in one of three formats 1. URL: with `sqlserver` scheme. username and password appears before the host. Any instance appears as - the first segment in the path. All other options are query parameters. Examples: - - * `sqlserver://username:password@host/instance?param1=value¶m2=value` - * `sqlserver://username:password@host:port?param1=value¶m2=value` - * `sqlserver://sa@localhost/SQLExpress?database=master&connection+timeout=30` // `SQLExpress instance. - * `sqlserver://sa:mypass@localhost?database=master&connection+timeout=30` // username=sa, password=mypass. - * `sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30` // port 1234 on localhost. - * `sqlserver://sa:my%7Bpass@somehost?connection+timeout=30` // password is "my{pass" - A string of this format can be constructed using the `URL` type in the `net/url` package. - - ```go - - query := url.Values{} - query.Add("app name", "MyAppName") - - u := &url.URL{ - Scheme: "sqlserver", - User: url.UserPassword(username, password), - Host: fmt.Sprintf("%s:%d", hostname, port), - // Path: instance, // if connecting to an instance instead of a port - RawQuery: query.Encode(), - } - db, err := sql.Open("sqlserver", u.String()) - - ``` + the first segment in the path. All other options are query parameters. Examples: + + - `sqlserver://username:password@host/instance?param1=value¶m2=value` + - `sqlserver://username:password@host:port?param1=value¶m2=value` + - `sqlserver://sa@localhost/SQLExpress?database=master&connection+timeout=30` // `SQLExpress instance. + - `sqlserver://sa:mypass@localhost?database=master&connection+timeout=30` // username=sa, password=mypass. + - `sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30` // port 1234 on localhost. + - `sqlserver://sa:my%7Bpass@somehost?connection+timeout=30` // password is "my{pass" + A string of this format can be constructed using the `URL` type in the `net/url` package. + + ```go + + query := url.Values{} + query.Add("app name", "MyAppName") + + u := &url.URL{ + Scheme: "sqlserver", + User: url.UserPassword(username, password), + Host: fmt.Sprintf("%s:%d", hostname, port), + // Path: instance, // if connecting to an instance instead of a port + RawQuery: query.Encode(), + } + db, err := sql.Open("sqlserver", u.String()) + + ``` + + - `sqlserver://username@host/instance?enablekerberos=true&krb5conffile=path/to/file&initkrbwithkeytab=false&krbcache=/path/to/cache` + - `sqlserver://username@host/instance?enablekerberos=true&krb5conffile=path/to/file&realm=domain.com&initkrbwithkeytab=true&keytabfile=/path/to/keytabfile` 2. ADO: `key=value` pairs separated by `;`. Values may not contain `;`, leading and trailing whitespace is ignored. - Examples: + Examples: - * `server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName` - * `server=localhost;user id=sa;database=master;app name=MyAppName` + - `server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName` + - `server=localhost;user id=sa;database=master;app name=MyAppName` + - `server=localhost;user id=sa;database=master;app name=MyAppName;enablekerberos=true;krb5conffile=path/to/file;initkrbwithkeytab=false;krbcache=path/to/cache` + - `server=localhost;user id=sa;database=master;app name=MyAppName;enablekerberos=true;krb5conffile=path/to/file;realm=domain.com;initkrbwithkeytab=true;keytabfile=path/to/keytabfile` - ADO strings support synonyms for database, app name, user id, and server - * server <= addr, address, network address, data source - * user id <= user, uid - * database <= initial catalog - * app name <= application name + ADO strings support synonyms for database, app name, user id, and server + + - server <= addr, address, network address, data source + - user id <= user, uid + - database <= initial catalog + - app name <= application name 3. ODBC: Prefix with `odbc`, `key=value` pairs separated by `;`. Allow `;` by wrapping - values in `{}`. Examples: - - * `odbc:server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName` - * `odbc:server=localhost;user id=sa;database=master;app name=MyAppName` - * `odbc:server=localhost;user id=sa;password={foo;bar}` // Value marked with `{}`, password is "foo;bar" - * `odbc:server=localhost;user id=sa;password={foo{bar}` // Value marked with `{}`, password is "foo{bar" - * `odbc:server=localhost;user id=sa;password={foobar }` // Value marked with `{}`, password is "foobar " - * `odbc:server=localhost;user id=sa;password=foo{bar` // Literal `{`, password is "foo{bar" - * `odbc:server=localhost;user id=sa;password=foo}bar` // Literal `}`, password is "foo}bar" - * `odbc:server=localhost;user id=sa;password={foo{bar}` // Literal `{`, password is "foo{bar" - * `odbc:server=localhost;user id=sa;password={foo}}bar}` // Escaped `} with`}}`, password is "foo}bar" + values in `{}`. Examples: + + - `odbc:server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName` + - `odbc:server=localhost;user id=sa;database=master;app name=MyAppName` + - `odbc:server=localhost;user id=sa;password={foo;bar}` // Value marked with `{}`, password is "foo;bar" + - `odbc:server=localhost;user id=sa;password={foo{bar}` // Value marked with `{}`, password is "foo{bar" + - `odbc:server=localhost;user id=sa;password={foobar }` // Value marked with `{}`, password is "foobar " + - `odbc:server=localhost;user id=sa;password=foo{bar` // Literal `{`, password is "foo{bar" + - `odbc:server=localhost;user id=sa;password=foo}bar` // Literal `}`, password is "foo}bar" + - `odbc:server=localhost;user id=sa;password={foo{bar}` // Literal `{`, password is "foo{bar" + - `odbc:server=localhost;user id=sa;password={foo}}bar}` // Escaped `} with`}}`, password is "foo}bar" + - `odbc:server=localhost;user id=sa;database=master;app name=MyAppName;enablekerberos=true;krb5conffile=path/to/file;initkrbwithkeytab=false;krbcache=path/to/cache` + - `odbc:server=localhost;user id=sa;database=master;app name=MyAppName;enablekerberos=true;krb5conffile=path/to/file;initkrbwithkeytab=true;realm=domain.com;keytabfile=path/to/keytabfile` ### Azure Active Directory authentication - preview @@ -120,7 +137,7 @@ Azure Active Directory (AAD) access tokens are relatively short lived and need t valid when a new connection is made. Authentication is supported using a callback func that provides a fresh and valid token using a connector: -``` go +```go conn, err := mssql.NewAccessTokenConnector( "Server=test.database.windows.net;Database=testdb", @@ -257,43 +274,43 @@ To pass specific types to the query parameters, say `varchar` or `date` types, you must convert the types to the type before passing in. The following types are supported: -* string -> nvarchar -* mssql.VarChar -> varchar -* time.Time -> datetimeoffset or datetime (TDS version dependent) -* mssql.DateTime1 -> datetime -* mssql.DateTimeOffset -> datetimeoffset -* "github.com/golang-sql/civil".Date -> date -* "github.com/golang-sql/civil".DateTime -> datetime2 -* "github.com/golang-sql/civil".Time -> time -* mssql.TVP -> Table Value Parameter (TDS version dependent) +- string -> nvarchar +- mssql.VarChar -> varchar +- time.Time -> datetimeoffset or datetime (TDS version dependent) +- mssql.DateTime1 -> datetime +- mssql.DateTimeOffset -> datetimeoffset +- "github.com/golang-sql/civil".Date -> date +- "github.com/golang-sql/civil".DateTime -> datetime2 +- "github.com/golang-sql/civil".Time -> time +- mssql.TVP -> Table Value Parameter (TDS version dependent) ## Important Notes -* [LastInsertId](https://golang.org/pkg/database/sql/#Result.LastInsertId) should - not be used with this driver (or SQL Server) due to how the TDS protocol - works. Please use the [OUTPUT Clause](https://docs.microsoft.com/en-us/sql/t-sql/queries/output-clause-transact-sql) - or add a `select ID = convert(bigint, SCOPE_IDENTITY());` to the end of your - query (ref [SCOPE_IDENTITY](https://docs.microsoft.com/en-us/sql/t-sql/functions/scope-identity-transact-sql)). - This will ensure you are getting the correct ID and will prevent a network round trip. -* [NewConnector](https://godoc.org/github.com/denisenkom/go-mssqldb#NewConnector) - may be used with [OpenDB](https://golang.org/pkg/database/sql/#OpenDB). -* [Connector.SessionInitSQL](https://godoc.org/github.com/denisenkom/go-mssqldb#Connector.SessionInitSQL) - may be set to set any driver specific session settings after the session - has been reset. If empty the session will still be reset but use the database - defaults in Go1.10+. +- [LastInsertId](https://golang.org/pkg/database/sql/#Result.LastInsertId) should + not be used with this driver (or SQL Server) due to how the TDS protocol + works. Please use the [OUTPUT Clause](https://docs.microsoft.com/en-us/sql/t-sql/queries/output-clause-transact-sql) + or add a `select ID = convert(bigint, SCOPE_IDENTITY());` to the end of your + query (ref [SCOPE_IDENTITY](https://docs.microsoft.com/en-us/sql/t-sql/functions/scope-identity-transact-sql)). + This will ensure you are getting the correct ID and will prevent a network round trip. +- [NewConnector](https://godoc.org/github.com/denisenkom/go-mssqldb#NewConnector) + may be used with [OpenDB](https://golang.org/pkg/database/sql/#OpenDB). +- [Connector.SessionInitSQL](https://godoc.org/github.com/denisenkom/go-mssqldb#Connector.SessionInitSQL) + may be set to set any driver specific session settings after the session + has been reset. If empty the session will still be reset but use the database + defaults in Go1.10+. ## Features -* Can be used with SQL Server 2005 or newer -* Can be used with Microsoft Azure SQL Database -* Can be used on all go supported platforms (e.g. Linux, Mac OS X and Windows) -* Supports new date/time types: date, time, datetime2, datetimeoffset -* Supports string parameters longer than 8000 characters -* Supports encryption using SSL/TLS -* Supports SQL Server and Windows Authentication -* Supports Single-Sign-On on Windows -* Supports connections to AlwaysOn Availability Group listeners, including re-direction to read-only replicas. -* Supports query notifications +- Can be used with SQL Server 2005 or newer +- Can be used with Microsoft Azure SQL Database +- Can be used on all go supported platforms (e.g. Linux, Mac OS X and Windows) +- Supports new date/time types: date, time, datetime2, datetimeoffset +- Supports string parameters longer than 8000 characters +- Supports encryption using SSL/TLS +- Supports SQL Server and Windows Authentication +- Supports Single-Sign-On on Windows +- Supports connections to AlwaysOn Availability Group listeners, including re-direction to read-only replicas. +- Supports query notifications ## Tests @@ -315,10 +332,10 @@ These features still exist in the driver, but they are are deprecated. If you use the driver name "mssql" (rather then "sqlserver") the SQL text will be loosly parsed and an attempt to extract identifiers using one of -* ? -* ?nnn -* :nnn -* $nnn +- ? +- ?nnn +- :nnn +- $nnn will be used. This is not recommended with SQL Server. There is at least one existing `won't fix` issue with the query parsing. @@ -327,7 +344,7 @@ Use the native "@Name" parameters instead with the "sqlserver" driver name. ## Known Issues -* SQL Server 2008 and 2008 R2 engine cannot handle login records when SSL encryption is not disabled. -To fix SQL Server 2008 R2 issue, install SQL Server 2008 R2 Service Pack 2. -To fix SQL Server 2008 issue, install Microsoft SQL Server 2008 Service Pack 3 and Cumulative update package 3 for SQL Server 2008 SP3. -More information: +- SQL Server 2008 and 2008 R2 engine cannot handle login records when SSL encryption is not disabled. + To fix SQL Server 2008 R2 issue, install SQL Server 2008 R2 Service Pack 2. + To fix SQL Server 2008 issue, install Microsoft SQL Server 2008 Service Pack 3 and Cumulative update package 3 for SQL Server 2008 SP3. + More information: From ca67c066df37fdeeb4a40df6127483421d1259fc Mon Sep 17 00:00:00 2001 From: chandanjainn Date: Thu, 20 Jan 2022 12:12:31 +0530 Subject: [PATCH 11/21] Wroked on review comments --- README.md | 15 +++++++-------- msdsn/conn_str.go | 3 --- msdsn/conn_str_test.go | 12 +++--------- tds.go | 2 +- 4 files changed, 11 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 2956ad1f..b7613d94 100644 --- a/README.md +++ b/README.md @@ -31,10 +31,8 @@ Other supported formats are listed below. ### Kerberos Parameters -- `enablekerberos`-It is a boolean flag to enable kerberos authentication mechanism. - `krb5conffile`-File path for kerberos configuration file. - `realm`-Domain name for kerberos authentication. -- `initkrbwithkeytab`-It is a boolean flag to enable kerberos authentication using keytab file. - `keytabfile`-Keytab file path. - `krbcache`-Credential cache path. @@ -96,16 +94,16 @@ Other supported formats are listed below. ``` - - `sqlserver://username@host/instance?enablekerberos=true&krb5conffile=path/to/file&initkrbwithkeytab=false&krbcache=/path/to/cache` - - `sqlserver://username@host/instance?enablekerberos=true&krb5conffile=path/to/file&realm=domain.com&initkrbwithkeytab=true&keytabfile=/path/to/keytabfile` + - `sqlserver://username@host/instance?krb5conffile=path/to/file&krbcache=/path/to/cache` + - `sqlserver://username@host/instance?krb5conffile=path/to/file&realm=domain.com&keytabfile=/path/to/keytabfile` 2. ADO: `key=value` pairs separated by `;`. Values may not contain `;`, leading and trailing whitespace is ignored. Examples: - `server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName` - `server=localhost;user id=sa;database=master;app name=MyAppName` - - `server=localhost;user id=sa;database=master;app name=MyAppName;enablekerberos=true;krb5conffile=path/to/file;initkrbwithkeytab=false;krbcache=path/to/cache` - - `server=localhost;user id=sa;database=master;app name=MyAppName;enablekerberos=true;krb5conffile=path/to/file;realm=domain.com;initkrbwithkeytab=true;keytabfile=path/to/keytabfile` + - `server=localhost;user id=sa;database=master;app name=MyAppName;krb5conffile=path/to/file;krbcache=path/to/cache` + - `server=localhost;user id=sa;database=master;app name=MyAppName;krb5conffile=path/to/file;realm=domain.com;keytabfile=path/to/keytabfile` ADO strings support synonyms for database, app name, user id, and server @@ -126,8 +124,8 @@ Other supported formats are listed below. - `odbc:server=localhost;user id=sa;password=foo}bar` // Literal `}`, password is "foo}bar" - `odbc:server=localhost;user id=sa;password={foo{bar}` // Literal `{`, password is "foo{bar" - `odbc:server=localhost;user id=sa;password={foo}}bar}` // Escaped `} with`}}`, password is "foo}bar" - - `odbc:server=localhost;user id=sa;database=master;app name=MyAppName;enablekerberos=true;krb5conffile=path/to/file;initkrbwithkeytab=false;krbcache=path/to/cache` - - `odbc:server=localhost;user id=sa;database=master;app name=MyAppName;enablekerberos=true;krb5conffile=path/to/file;initkrbwithkeytab=true;realm=domain.com;keytabfile=path/to/keytabfile` + - `odbc:server=localhost;user id=sa;database=master;app name=MyAppName;krb5conffile=path/to/file;krbcache=path/to/cache` + - `odbc:server=localhost;user id=sa;database=master;app name=MyAppName;krb5conffile=path/to/file;realm=domain.com;keytabfile=path/to/keytabfile` ### Azure Active Directory authentication @@ -324,6 +322,7 @@ are supported: - Supports Single-Sign-On on Windows - Supports connections to AlwaysOn Availability Group listeners, including re-direction to read-only replicas. - Supports query notifications +- Supports Kerberos Authentication ## Tests diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index ae1951cc..b8f89503 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -96,9 +96,6 @@ type Config struct { PacketSize uint16 Kerberos *Kerberos - - // Flag to enable kerberos authentication - EnableKerberos bool } func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate string) (*tls.Config, error) { diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index ac35b128..84f352dc 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -203,20 +203,15 @@ func TestConnParseRoundTripFixed(t *testing.T) { func TestInvalidConnectionStringKerberos(t *testing.T) { connStrings := []string{ - "server=server;port=1345;realm=domain;trustservercertificate=true;keytabfile=/path/to/administrator2.keytab;enablekerberos=true", - "server=server;port=1345;realm=domain;trustservercertificate=true;krbcache=;enablekerberos=true", - "server=server;user id=user;password=pwd;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;enablekerberos=true", - "server=server;user id=user;password=pwd;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;keytabfile=/path/to/administrator2.keytab;enablekerberos=true", - "server=server;user id=user;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;keytabfile=/path/to/administrator2.keytab;enablekerberos=true;initkrbwithkeytab=false", - "server=server;user id=user;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;enablekerberos=true;initkrbwithkeytab=true", + "server=server;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;", + "server=server;user id=user;password=pwd;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;", + "server=server;user id=user;password=pwd;port=1345;trustservercertificate=true;krb5conffile=/etc/krb5.conf;keytabfile=/path/to/administrator2.keytab;", } for _, connStr := range connStrings { _, _, err := Parse(connStr) if err == nil { t.Errorf("Connection expected to fail for connection string %s but it didn't", connStr) continue - } else { - t.Logf("Connection failed for %s as expected with error %v", connStr, err) } } } @@ -240,7 +235,6 @@ func TestValidConnectionStringKerberos(t *testing.T) { func createKrbFile(filename string, t *testing.T) string { if _, err := os.Stat("temp"); os.IsNotExist(err) { err := os.Mkdir("temp", 0755) - // TODO: handle error if err != nil { t.Errorf("Failed to create a temporary directory") } diff --git a/tds.go b/tds.go index 0944a7f0..228d6744 100644 --- a/tds.go +++ b/tds.go @@ -1166,7 +1166,7 @@ initiate_connection: } var auth auth var authOk bool - if p.Kerberos != nil { + if p.Kerberos.Krb5Conf != nil { auth, authOk = getKRB5Auth(p.User, p.Password, p.ServerSPN, p.Kerberos.Krb5Conf, p.Kerberos.KrbKeytab, p.Kerberos.KrbCache) } else { auth, authOk = getAuth(p.User, p.Password, p.ServerSPN, p.Workstation) From 6e619bee49adcce1a4d02fed2f76152b797eba5f Mon Sep 17 00:00:00 2001 From: chandanjainn Date: Thu, 20 Jan 2022 13:19:48 +0530 Subject: [PATCH 12/21] fixed the changes for readme. --- README.md | 290 ++++++++++++++++++++++++++-------------------------- kerbauth.go | 2 - 2 files changed, 144 insertions(+), 148 deletions(-) diff --git a/README.md b/README.md index b7613d94..1be09966 100644 --- a/README.md +++ b/README.md @@ -18,114 +18,112 @@ Other supported formats are listed below. ### Common parameters -- `user id` - enter the SQL Server Authentication user id or the Windows Authentication user id in the DOMAIN\User format. On Windows, if user id is empty or missing Single-Sign-On is used. The user domain sensitive to the case which is defined in the connection string. -- `password` -- `database` -- `connection timeout` - in seconds (default is 0 for no timeout), set to 0 for no timeout. Recommended to set to 0 and use context to manage query and connection timeouts. -- `dial timeout` - in seconds (default is 15), set to 0 for no timeout -- `encrypt` - - `disable` - Data send between client and server is not encrypted. - - `false` - Data sent between client and server is not encrypted beyond the login packet. (Default) - - `true` - Data sent between client and server is encrypted. -- `app name` - The application name (default is go-mssqldb) - -### Kerberos Parameters - -- `krb5conffile`-File path for kerberos configuration file. -- `realm`-Domain name for kerberos authentication. -- `keytabfile`-Keytab file path. -- `krbcache`-Credential cache path. +* `user id` - enter the SQL Server Authentication user id or the Windows Authentication user id in the DOMAIN\User format. On Windows, if user id is empty or missing Single-Sign-On is used. The user domain sensitive to the case which is defined in the connection string. +* `password` +* `database` +* `connection timeout` - in seconds (default is 0 for no timeout), set to 0 for no timeout. Recommended to set to 0 and use context to manage query and connection timeouts. +* `dial timeout` - in seconds (default is 15), set to 0 for no timeout +* `encrypt` + * `disable` - Data send between client and server is not encrypted. + * `false` - Data sent between client and server is not encrypted beyond the login packet. (Default) + * `true` - Data sent between client and server is encrypted. +* `app name` - The application name (default is go-mssqldb) ### Connection parameters for ODBC and ADO style connection strings -- `server` - host or host\instance (default localhost) -- `port` - used only when there is no instance in server (default 1433) +* `server` - host or host\instance (default localhost) +* `port` - used only when there is no instance in server (default 1433) ### Less common parameters -- `keepAlive` - in seconds; 0 to disable (default is 30) -- `failoverpartner` - host or host\instance (default is no partner). -- `failoverport` - used only when there is no instance in failoverpartner (default 1433) -- `packet size` - in bytes; 512 to 32767 (default is 4096) - - Encrypted connections have a maximum packet size of 16383 bytes - - Further information on usage: -- `log` - logging flags (default 0/no logging, 63 for full logging) - - 1 log errors - - 2 log messages - - 4 log rows affected - - 8 trace sql statements - - 16 log statement parameters - - 32 log transaction begin/end -- `TrustServerCertificate` - - false - Server certificate is checked. Default is false if encypt is specified. - - true - Server certificate is not checked. Default is true if encrypt is not specified. If trust server certificate is true, driver accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing. -- `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates. -- `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host. -- `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port. -- `Workstation ID` - The workstation name (default is the host name) -- `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`. - -### The connection string can be specified in one of three formats - -1. URL: with `sqlserver` scheme. username and password appears before the host. Any instance appears as - the first segment in the path. All other options are query parameters. Examples: - - - `sqlserver://username:password@host/instance?param1=value¶m2=value` - - `sqlserver://username:password@host:port?param1=value¶m2=value` - - `sqlserver://sa@localhost/SQLExpress?database=master&connection+timeout=30` // `SQLExpress instance. - - `sqlserver://sa:mypass@localhost?database=master&connection+timeout=30` // username=sa, password=mypass. - - `sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30` // port 1234 on localhost. - - `sqlserver://sa:my%7Bpass@somehost?connection+timeout=30` // password is "my{pass" - A string of this format can be constructed using the `URL` type in the `net/url` package. - - ```go +* `keepAlive` - in seconds; 0 to disable (default is 30) +* `failoverpartner` - host or host\instance (default is no partner). +* `failoverport` - used only when there is no instance in failoverpartner (default 1433) +* `packet size` - in bytes; 512 to 32767 (default is 4096) + * Encrypted connections have a maximum packet size of 16383 bytes + * Further information on usage: +* `log` - logging flags (default 0/no logging, 63 for full logging) + * 1 log errors + * 2 log messages + * 4 log rows affected + * 8 trace sql statements + * 16 log statement parameters + * 32 log transaction begin/end +* `TrustServerCertificate` + * false - Server certificate is checked. Default is false if encypt is specified. + * true - Server certificate is not checked. Default is true if encrypt is not specified. If trust server certificate is true, driver accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing. +* `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates. +* `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host. +* `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port. +* `Workstation ID` - The workstation name (default is the host name) +* `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`. - query := url.Values{} - query.Add("app name", "MyAppName") - - u := &url.URL{ - Scheme: "sqlserver", - User: url.UserPassword(username, password), - Host: fmt.Sprintf("%s:%d", hostname, port), - // Path: instance, // if connecting to an instance instead of a port - RawQuery: query.Encode(), - } - db, err := sql.Open("sqlserver", u.String()) +### Kerberos Parameters - ``` +* `krb5conffile` - File path for kerberos configuration file. +* `realm` - Domain name for kerberos authentication. +* `keytabfile` - Keytab file path. +* `krbcache` - Credential cache path. - - `sqlserver://username@host/instance?krb5conffile=path/to/file&krbcache=/path/to/cache` - - `sqlserver://username@host/instance?krb5conffile=path/to/file&realm=domain.com&keytabfile=/path/to/keytabfile` +### The connection string can be specified in one of three formats +1. URL: with `sqlserver` scheme. username and password appears before the host. Any instance appears as + the first segment in the path. All other options are query parameters. Examples: + + * `sqlserver://username:password@host/instance?param1=value¶m2=value` + * `sqlserver://username:password@host:port?param1=value¶m2=value` + * `sqlserver://sa@localhost/SQLExpress?database=master&connection+timeout=30` // `SQLExpress instance. + * `sqlserver://sa:mypass@localhost?database=master&connection+timeout=30` // username=sa, password=mypass. + * `sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30` // port 1234 on localhost. + * `sqlserver://sa:my%7Bpass@somehost?connection+timeout=30` // password is "my{pass" + A string of this format can be constructed using the `URL` type in the `net/url` package. + + ```go + + query := url.Values{} + query.Add("app name", "MyAppName") + + u := &url.URL{ + Scheme: "sqlserver", + User: url.UserPassword(username, password), + Host: fmt.Sprintf("%s:%d", hostname, port), + // Path: instance, // if connecting to an instance instead of a port + RawQuery: query.Encode(), + } + db, err := sql.Open("sqlserver", u.String()) + + ``` + * `sqlserver://username@host/instance?krb5conffile=path/to/file&krbcache=/path/to/cache` + * `sqlserver://username@host/instance?krb5conffile=path/to/file&realm=domain.com&keytabfile=/path/to/keytabfile` + 2. ADO: `key=value` pairs separated by `;`. Values may not contain `;`, leading and trailing whitespace is ignored. - Examples: - - - `server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName` - - `server=localhost;user id=sa;database=master;app name=MyAppName` - - `server=localhost;user id=sa;database=master;app name=MyAppName;krb5conffile=path/to/file;krbcache=path/to/cache` - - `server=localhost;user id=sa;database=master;app name=MyAppName;krb5conffile=path/to/file;realm=domain.com;keytabfile=path/to/keytabfile` + Examples: - ADO strings support synonyms for database, app name, user id, and server + * `server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName` + * `server=localhost;user id=sa;database=master;app name=MyAppName` + * `server=localhost;user id=sa;database=master;app name=MyAppName;krb5conffile=path/to/file;krbcache=path/to/cache` + * `server=localhost;user id=sa;database=master;app name=MyAppName;krb5conffile=path/to/file;realm=domain.com;keytabfile=path/to/keytabfile` - - server <= addr, address, network address, data source - - user id <= user, uid - - database <= initial catalog - - app name <= application name + ADO strings support synonyms for database, app name, user id, and server + * server <= addr, address, network address, data source + * user id <= user, uid + * database <= initial catalog + * app name <= application name 3. ODBC: Prefix with `odbc`, `key=value` pairs separated by `;`. Allow `;` by wrapping - values in `{}`. Examples: - - - `odbc:server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName` - - `odbc:server=localhost;user id=sa;database=master;app name=MyAppName` - - `odbc:server=localhost;user id=sa;password={foo;bar}` // Value marked with `{}`, password is "foo;bar" - - `odbc:server=localhost;user id=sa;password={foo{bar}` // Value marked with `{}`, password is "foo{bar" - - `odbc:server=localhost;user id=sa;password={foobar }` // Value marked with `{}`, password is "foobar " - - `odbc:server=localhost;user id=sa;password=foo{bar` // Literal `{`, password is "foo{bar" - - `odbc:server=localhost;user id=sa;password=foo}bar` // Literal `}`, password is "foo}bar" - - `odbc:server=localhost;user id=sa;password={foo{bar}` // Literal `{`, password is "foo{bar" - - `odbc:server=localhost;user id=sa;password={foo}}bar}` // Escaped `} with`}}`, password is "foo}bar" - - `odbc:server=localhost;user id=sa;database=master;app name=MyAppName;krb5conffile=path/to/file;krbcache=path/to/cache` - - `odbc:server=localhost;user id=sa;database=master;app name=MyAppName;krb5conffile=path/to/file;realm=domain.com;keytabfile=path/to/keytabfile` + values in `{}`. Examples: + + * `odbc:server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName` + * `odbc:server=localhost;user id=sa;database=master;app name=MyAppName` + * `odbc:server=localhost;user id=sa;password={foo;bar}` // Value marked with `{}`, password is "foo;bar" + * `odbc:server=localhost;user id=sa;password={foo{bar}` // Value marked with `{}`, password is "foo{bar" + * `odbc:server=localhost;user id=sa;password={foobar }` // Value marked with `{}`, password is "foobar " + * `odbc:server=localhost;user id=sa;password=foo{bar` // Literal `{`, password is "foo{bar" + * `odbc:server=localhost;user id=sa;password=foo}bar` // Literal `}`, password is "foo}bar" + * `odbc:server=localhost;user id=sa;password={foo{bar}` // Literal `{`, password is "foo{bar" + * `odbc:server=localhost;user id=sa;password={foo}}bar}` // Escaped `} with`}}`, password is "foo}bar" + * `odbc:server=localhost;user id=sa;database=master;app name=MyAppName;krb5conffile=path/to/file;krbcache=path/to/cache` + * `odbc:server=localhost;user id=sa;database=master;app name=MyAppName;krb5conffile=path/to/file;realm=domain.com;keytabfile=path/to/keytabfile` ### Azure Active Directory authentication @@ -134,19 +132,19 @@ The `mssql` package does not provide an implementation to obtain tokens: instead The credential type is determined by the new `fedauth` connection string parameter. -- `fedauth=ActiveDirectoryServicePrincipal` or `fedauth=ActiveDirectoryApplication` - authenticates using an Azure Active Directory application client ID and client secret or certificate. Implemented using [ClientSecretCredential or CertificateCredential](https://github.com/Azure/azure-sdk-for-go/tree/main/sdk/azidentity#authenticating-service-principals) - - `clientcertpath=;password=` or - - `password=` - - `user id=[@tenantid]` Note the `@tenantid` component can be omitted if the server's tenant is the same as the application's tenant. -- `fedauth=ActiveDirectoryPassword` - authenticates using a user name and password. - - `user id=username@domain` - - `password=` - - `applicationclientid=` - This guid identifies an Azure Active Directory enterprise application that the AAD admin has approved for accessing Azure SQL database resources in the tenant. This driver does not have an associated application id of its own. -- `fedauth=ActiveDirectoryDefault` - authenticates using a chained set of credentials. The chain is built from EnvironmentCredential -> ManagedIdentityCredential->AzureCLICredential. See [DefaultAzureCredential docs](https://github.com/Azure/azure-sdk-for-go/wiki/Set-up-Your-Environment-for-Authentication#configure-defaultazurecredential) for instructions on setting up your host environment to use it. Using this option allows you to have the same connection string in a service deployment as on your interactive development machine. -- `fedauth=ActiveDirectoryManagedIdentity` or `fedauth=ActiveDirectoryMSI` - authenticates using a system-assigned or user-assigned Azure Managed Identity. - - `user id=` - optional id of user-assigned managed identity. If empty, system-assigned managed identity is used. -- `fedauth=ActiveDirectoryInteractive` - authenticates using credentials acquired from an external web browser. Only suitable for use with human interaction. - - `applicationclientid=` - This guid identifies an Azure Active Directory enterprise application that the AAD admin has approved for accessing Azure SQL database resources in the tenant. This driver does not have an associated application id of its own. +* `fedauth=ActiveDirectoryServicePrincipal` or `fedauth=ActiveDirectoryApplication` - authenticates using an Azure Active Directory application client ID and client secret or certificate. Implemented using [ClientSecretCredential or CertificateCredential](https://github.com/Azure/azure-sdk-for-go/tree/main/sdk/azidentity#authenticating-service-principals) + * `clientcertpath=;password=` or + * `password=` + * `user id=[@tenantid]` Note the `@tenantid` component can be omitted if the server's tenant is the same as the application's tenant. +* `fedauth=ActiveDirectoryPassword` - authenticates using a user name and password. + * `user id=username@domain` + * `password=` + * `applicationclientid=` - This guid identifies an Azure Active Directory enterprise application that the AAD admin has approved for accessing Azure SQL database resources in the tenant. This driver does not have an associated application id of its own. +* `fedauth=ActiveDirectoryDefault` - authenticates using a chained set of credentials. The chain is built from EnvironmentCredential -> ManagedIdentityCredential->AzureCLICredential. See [DefaultAzureCredential docs](https://github.com/Azure/azure-sdk-for-go/wiki/Set-up-Your-Environment-for-Authentication#configure-defaultazurecredential) for instructions on setting up your host environment to use it. Using this option allows you to have the same connection string in a service deployment as on your interactive development machine. +* `fedauth=ActiveDirectoryManagedIdentity` or `fedauth=ActiveDirectoryMSI` - authenticates using a system-assigned or user-assigned Azure Managed Identity. + * `user id=` - optional id of user-assigned managed identity. If empty, system-assigned managed identity is used. +* `fedauth=ActiveDirectoryInteractive` - authenticates using credentials acquired from an external web browser. Only suitable for use with human interaction. + * `applicationclientid=` - This guid identifies an Azure Active Directory enterprise application that the AAD admin has approved for accessing Azure SQL database resources in the tenant. This driver does not have an associated application id of its own. ```go @@ -285,44 +283,44 @@ To pass specific types to the query parameters, say `varchar` or `date` types, you must convert the types to the type before passing in. The following types are supported: -- string -> nvarchar -- mssql.VarChar -> varchar -- time.Time -> datetimeoffset or datetime (TDS version dependent) -- mssql.DateTime1 -> datetime -- mssql.DateTimeOffset -> datetimeoffset -- "github.com/golang-sql/civil".Date -> date -- "github.com/golang-sql/civil".DateTime -> datetime2 -- "github.com/golang-sql/civil".Time -> time -- mssql.TVP -> Table Value Parameter (TDS version dependent) +* string -> nvarchar +* mssql.VarChar -> varchar +* time.Time -> datetimeoffset or datetime (TDS version dependent) +* mssql.DateTime1 -> datetime +* mssql.DateTimeOffset -> datetimeoffset +* "github.com/golang-sql/civil".Date -> date +* "github.com/golang-sql/civil".DateTime -> datetime2 +* "github.com/golang-sql/civil".Time -> time +* mssql.TVP -> Table Value Parameter (TDS version dependent) ## Important Notes -- [LastInsertId](https://golang.org/pkg/database/sql/#Result.LastInsertId) should - not be used with this driver (or SQL Server) due to how the TDS protocol - works. Please use the [OUTPUT Clause](https://docs.microsoft.com/en-us/sql/t-sql/queries/output-clause-transact-sql) - or add a `select ID = convert(bigint, SCOPE_IDENTITY());` to the end of your - query (ref [SCOPE_IDENTITY](https://docs.microsoft.com/en-us/sql/t-sql/functions/scope-identity-transact-sql)). - This will ensure you are getting the correct ID and will prevent a network round trip. -- [NewConnector](https://godoc.org/github.com/denisenkom/go-mssqldb#NewConnector) - may be used with [OpenDB](https://golang.org/pkg/database/sql/#OpenDB). -- [Connector.SessionInitSQL](https://godoc.org/github.com/denisenkom/go-mssqldb#Connector.SessionInitSQL) - may be set to set any driver specific session settings after the session - has been reset. If empty the session will still be reset but use the database - defaults in Go1.10+. +* [LastInsertId](https://golang.org/pkg/database/sql/#Result.LastInsertId) should + not be used with this driver (or SQL Server) due to how the TDS protocol + works. Please use the [OUTPUT Clause](https://docs.microsoft.com/en-us/sql/t-sql/queries/output-clause-transact-sql) + or add a `select ID = convert(bigint, SCOPE_IDENTITY());` to the end of your + query (ref [SCOPE_IDENTITY](https://docs.microsoft.com/en-us/sql/t-sql/functions/scope-identity-transact-sql)). + This will ensure you are getting the correct ID and will prevent a network round trip. +* [NewConnector](https://godoc.org/github.com/denisenkom/go-mssqldb#NewConnector) + may be used with [OpenDB](https://golang.org/pkg/database/sql/#OpenDB). +* [Connector.SessionInitSQL](https://godoc.org/github.com/denisenkom/go-mssqldb#Connector.SessionInitSQL) + may be set to set any driver specific session settings after the session + has been reset. If empty the session will still be reset but use the database + defaults in Go1.10+. ## Features -- Can be used with SQL Server 2005 or newer -- Can be used with Microsoft Azure SQL Database -- Can be used on all go supported platforms (e.g. Linux, Mac OS X and Windows) -- Supports new date/time types: date, time, datetime2, datetimeoffset -- Supports string parameters longer than 8000 characters -- Supports encryption using SSL/TLS -- Supports SQL Server and Windows Authentication -- Supports Single-Sign-On on Windows -- Supports connections to AlwaysOn Availability Group listeners, including re-direction to read-only replicas. -- Supports query notifications -- Supports Kerberos Authentication +* Can be used with SQL Server 2005 or newer +* Can be used with Microsoft Azure SQL Database +* Can be used on all go supported platforms (e.g. Linux, Mac OS X and Windows) +* Supports new date/time types: date, time, datetime2, datetimeoffset +* Supports string parameters longer than 8000 characters +* Supports encryption using SSL/TLS +* Supports SQL Server and Windows Authentication +* Supports Single-Sign-On on Windows +* Supports connections to AlwaysOn Availability Group listeners, including re-direction to read-only replicas. +* Supports query notifications +* Supports Kerberos Authentication ## Tests @@ -346,10 +344,10 @@ These features still exist in the driver, but they are are deprecated. If you use the driver name "mssql" (rather then "sqlserver") the SQL text will be loosly parsed and an attempt to extract identifiers using one of -- ? -- ?nnn -- :nnn -- $nnn +* ? +* ?nnn +* :nnn +* $nnn will be used. This is not recommended with SQL Server. There is at least one existing `won't fix` issue with the query parsing. @@ -358,7 +356,7 @@ Use the native "@Name" parameters instead with the "sqlserver" driver name. ## Known Issues -- SQL Server 2008 and 2008 R2 engine cannot handle login records when SSL encryption is not disabled. - To fix SQL Server 2008 R2 issue, install SQL Server 2008 R2 Service Pack 2. - To fix SQL Server 2008 issue, install Microsoft SQL Server 2008 Service Pack 3 and Cumulative update package 3 for SQL Server 2008 SP3. - More information: +* SQL Server 2008 and 2008 R2 engine cannot handle login records when SSL encryption is not disabled. +To fix SQL Server 2008 R2 issue, install SQL Server 2008 R2 Service Pack 2. +To fix SQL Server 2008 issue, install Microsoft SQL Server 2008 Service Pack 3 and Cumulative update package 3 for SQL Server 2008 SP3. +More information: \ No newline at end of file diff --git a/kerbauth.go b/kerbauth.go index 01a76f70..475c944b 100644 --- a/kerbauth.go +++ b/kerbauth.go @@ -78,8 +78,6 @@ func getKRB5Auth(user, password, serverSPN string, krb5Conf *config.Config, keyt } func (auth *krb5Auth) InitialBytes() ([]byte, error) { - // Set to lookup KDCs in DNS - auth.krb5Config.LibDefaults.DNSLookupKDC = false var cl *client.Client var err error // Init keytab from conf From d4f52ce5060b395a0a1dd989332d35e6ac40be78 Mon Sep 17 00:00:00 2001 From: chandan jain Date: Thu, 20 Jan 2022 14:56:49 +0530 Subject: [PATCH 13/21] fix: whitespace --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1be09966..901393f6 100644 --- a/README.md +++ b/README.md @@ -359,4 +359,4 @@ Use the native "@Name" parameters instead with the "sqlserver" driver name. * SQL Server 2008 and 2008 R2 engine cannot handle login records when SSL encryption is not disabled. To fix SQL Server 2008 R2 issue, install SQL Server 2008 R2 Service Pack 2. To fix SQL Server 2008 issue, install Microsoft SQL Server 2008 Service Pack 3 and Cumulative update package 3 for SQL Server 2008 SP3. -More information: \ No newline at end of file +More information: From 530eb45b0e470342841d8eed62c995bc2202411c Mon Sep 17 00:00:00 2001 From: chandanjainn Date: Thu, 20 Jan 2022 16:55:03 +0530 Subject: [PATCH 14/21] worked on review comments --- msdsn/conn_str.go | 10 +++------- msdsn/conn_str_test.go | 40 +++++++++++++++------------------------- 2 files changed, 18 insertions(+), 32 deletions(-) diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index b8f89503..efc39a85 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -132,9 +132,6 @@ var skipSetup = errors.New("skip setting up TLS") func Parse(dsn string) (Config, map[string]string, error) { p := Config{} - k := Kerberos{} - - p.Kerberos = &k var params map[string]string if strings.HasPrefix(dsn, "odbc:") { @@ -209,14 +206,13 @@ func Parse(dsn string) (Config, map[string]string, error) { krb5ConfFile, ok := params["krb5conffile"] if ok { + p.Kerberos = &Kerberos{} var err error p.Kerberos.Krb5Conf, err = setupKerbConfig(krb5ConfFile) if err != nil { return p, params, fmt.Errorf("cannot read kerberos configuration file: %w", err) } - } - if ok { missingParam := checkMissingKRBConfig(params) if missingParam != "" { return p, params, fmt.Errorf("missing parameter:%s", missingParam) @@ -330,7 +326,7 @@ func Parse(dsn string) (Config, map[string]string, error) { if ok { p.ServerSPN = serverSPN } else { - p.ServerSPN = generateSpn(p.Host, resolveServerPort(p.Port), k.KrbRealm) + p.ServerSPN = generateSpn(p.Host, resolveServerPort(p.Port), p.Kerberos.KrbRealm) } workstation, ok := params["workstation id"] @@ -396,7 +392,7 @@ func checkMissingKRBConfig(c map[string]string) (missingParam string) { } } if c["krbcache"] == "" && c["keytabfile"] == "" { - missingParam = "krbcache or keytab" + missingParam = "atleast krbcache or keytab is required" return } return diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index 84f352dc..ea098f52 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -3,7 +3,6 @@ package msdsn import ( "io/ioutil" "os" - "path/filepath" "reflect" "testing" "time" @@ -203,9 +202,9 @@ func TestConnParseRoundTripFixed(t *testing.T) { func TestInvalidConnectionStringKerberos(t *testing.T) { connStrings := []string{ - "server=server;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;", - "server=server;user id=user;password=pwd;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/etc/krb5.conf;", - "server=server;user id=user;password=pwd;port=1345;trustservercertificate=true;krb5conffile=/etc/krb5.conf;keytabfile=/path/to/administrator2.keytab;", + "server=server;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/path/krb5.conf;", + "server=server;user id=user;password=pwd;port=1345;realm=domain;trustservercertificate=true;krb5conffile=/path/krb5.conf;", + "server=server;user id=user;password=pwd;port=1345;trustservercertificate=true;krb5conffile=/path/krb5.conf;keytabfile=/path/to/administrator2.keytab;", } for _, connStr := range connStrings { _, _, err := Parse(connStr) @@ -217,7 +216,7 @@ func TestInvalidConnectionStringKerberos(t *testing.T) { } func TestValidConnectionStringKerberos(t *testing.T) { - kerberosTestFile := createKrbFile("test.txt", t) + kerberosTestFile := createKrbFile(t) connStrings := []string{ "server=server;user id=user;port=1345;realm=domain;trustservercertificate=true;krb5conffile=" + kerberosTestFile + ";keytabfile=" + kerberosTestFile, "server=server;port=1345;realm=domain;trustservercertificate=true;krb5conffile=" + kerberosTestFile + ";krbcache=" + kerberosTestFile, @@ -232,30 +231,21 @@ func TestValidConnectionStringKerberos(t *testing.T) { deleteFile(t) } -func createKrbFile(filename string, t *testing.T) string { - if _, err := os.Stat("temp"); os.IsNotExist(err) { - err := os.Mkdir("temp", 0755) - if err != nil { - t.Errorf("Failed to create a temporary directory") - } +func createKrbFile(t *testing.T) string { + err := os.Mkdir("temp", 0755) + if err != nil { + t.Errorf("Failed to create a temporary directory") } - file := []byte("This is a test file") - err := ioutil.WriteFile("temp/"+filename, file, 0644) + file, err := ioutil.TempFile("temp", "test-*.txt") if err != nil { - t.Errorf("Could not write file") + t.Errorf("Failed to create a temp file") } - filedirectory := filepath.Dir(filename) - thepath, _ := filepath.Abs(filedirectory) - filePath := thepath + "/" + filename - - return filePath + if _, err := file.Write([]byte("This is a test file\n")); err != nil { + t.Errorf("Failed to write file") + } + return file.Name() } func deleteFile(t *testing.T) { - defer func() { - err := os.RemoveAll("temp") - if err != nil { - t.Errorf("Could not delete directory") - } - }() + os.RemoveAll("temp") } From 85c5bb155d6e72612606dd42df3f49717150ef4c Mon Sep 17 00:00:00 2001 From: chandanjainn Date: Thu, 20 Jan 2022 17:23:47 +0530 Subject: [PATCH 15/21] fix for unit testing --- README.md | 3 +++ msdsn/conn_str.go | 14 +++++++------- msdsn/conn_str_test.go | 15 +++++++-------- tds.go | 2 +- 4 files changed, 18 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 901393f6..28769f14 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,9 @@ Other supported formats are listed below. * `realm` - Domain name for kerberos authentication. * `keytabfile` - Keytab file path. * `krbcache` - Credential cache path. +* For further information on usage: + * + * ### The connection string can be specified in one of three formats diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index efc39a85..f9272a14 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -48,14 +48,14 @@ type Kerberos struct { Krb5Conf *config.Config // Credential cache - KrbCache *credentials.CCache + Cache *credentials.CCache // A Kerberos realm is the domain over which a Kerberos authentication server has the authority // to authenticate a user, host or service. - KrbRealm string + Realm string // Kerberos keytab that stores long-term keys for one or more principals - KrbKeytab *keytab.Keytab + Keytab *keytab.Keytab } type Config struct { @@ -220,13 +220,13 @@ func Parse(dsn string) (Config, map[string]string, error) { realm, ok := params["realm"] if ok { - p.Kerberos.KrbRealm = realm + p.Kerberos.Realm = realm } krbCache, ok := params["krbcache"] if ok { var err error - p.Kerberos.KrbCache, err = setupKerbCache(krbCache) + p.Kerberos.Cache, err = setupKerbCache(krbCache) if err != nil { return p, params, fmt.Errorf("cannot read kerberos cache file: %w", err) } @@ -235,7 +235,7 @@ func Parse(dsn string) (Config, map[string]string, error) { keytabfile, ok := params["keytabfile"] if ok { var err error - p.Kerberos.KrbKeytab, err = setupKerbKeytab(keytabfile) + p.Kerberos.Keytab, err = setupKerbKeytab(keytabfile) if err != nil { return p, params, fmt.Errorf("cannot read kerberos keytab file: %w", err) } @@ -326,7 +326,7 @@ func Parse(dsn string) (Config, map[string]string, error) { if ok { p.ServerSPN = serverSPN } else { - p.ServerSPN = generateSpn(p.Host, resolveServerPort(p.Port), p.Kerberos.KrbRealm) + p.ServerSPN = generateSpn(p.Host, resolveServerPort(p.Port), p.Kerberos.Realm) } workstation, ok := params["workstation id"] diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index ea098f52..5c58e876 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -1,6 +1,7 @@ package msdsn import ( + "fmt" "io/ioutil" "os" "reflect" @@ -228,15 +229,13 @@ func TestValidConnectionStringKerberos(t *testing.T) { t.Errorf("Connection string %s should fail to parse with error %s", connStrings, err) } } - deleteFile(t) + deleteFile(kerberosTestFile, t) } func createKrbFile(t *testing.T) string { - err := os.Mkdir("temp", 0755) - if err != nil { - t.Errorf("Failed to create a temporary directory") - } - file, err := ioutil.TempFile("temp", "test-*.txt") + dir := os.TempDir() + fmt.Println(dir) + file, err := ioutil.TempFile(dir, "test-*.txt") if err != nil { t.Errorf("Failed to create a temp file") } @@ -246,6 +245,6 @@ func createKrbFile(t *testing.T) string { return file.Name() } -func deleteFile(t *testing.T) { - os.RemoveAll("temp") +func deleteFile(filename string, t *testing.T) { + os.Remove(filename) } diff --git a/tds.go b/tds.go index 228d6744..deea9522 100644 --- a/tds.go +++ b/tds.go @@ -1167,7 +1167,7 @@ initiate_connection: var auth auth var authOk bool if p.Kerberos.Krb5Conf != nil { - auth, authOk = getKRB5Auth(p.User, p.Password, p.ServerSPN, p.Kerberos.Krb5Conf, p.Kerberos.KrbKeytab, p.Kerberos.KrbCache) + auth, authOk = getKRB5Auth(p.User, p.Password, p.ServerSPN, p.Kerberos.Krb5Conf, p.Kerberos.Keytab, p.Kerberos.Cache) } else { auth, authOk = getAuth(p.User, p.Password, p.ServerSPN, p.Workstation) } From 04c18be0a5a5c6d947f614af63b3ac58ec3e1d6b Mon Sep 17 00:00:00 2001 From: chandanjainn Date: Mon, 24 Jan 2022 13:56:55 +0530 Subject: [PATCH 16/21] Merge branch 'master' into kerberos_auth --- .pipelines/TestSql2017.yml | 4 +- appveyor.yml | 1 + bulkimport_example_test.go | 14 +- datetimeoffset_example_test.go | 13 +- go.mod | 1 + go.sum | 2 + lastinsertid_example_test.go | 14 +- messages_benchmark_test.go | 63 ++++++++ messages_example_test.go | 76 ++++++++++ mssql.go | 224 +++++++++++++++++++++++++++- mssql_go19.go | 6 + newconnector_example_test.go | 29 ++-- queries_go19_test.go | 263 +++++++++++++++++++++++++++++++++ queries_test.go | 2 +- tds.go | 8 + tds_go110_test.go | 3 +- tds_login_test.go | 20 ++- tds_test.go | 36 +++-- token.go | 54 ++++++- tvp_example_test.go | 14 +- tvp_go19.go | 3 +- tvp_go19_db_test.go | 187 +++++++++++++++++++++++ 22 files changed, 951 insertions(+), 86 deletions(-) create mode 100644 messages_benchmark_test.go create mode 100644 messages_example_test.go diff --git a/.pipelines/TestSql2017.yml b/.pipelines/TestSql2017.yml index 2a052e96..046e3e98 100644 --- a/.pipelines/TestSql2017.yml +++ b/.pipelines/TestSql2017.yml @@ -58,9 +58,9 @@ steps: workingDirectory: '$(Build.SourcesDirectory)' displayName: 'run tests' env: - SQLSERVER_DSN: 'server=.;user id=sa;password=$(TESTPASSWORD)' + SQLPASSWORD: $(SQLPASSWORD) AZURESERVER_DSN: $(AZURESERVER_DSN) - + SQLSERVER_DSN: $(SQLSERVER_DSN) continueOnError: true - task: PublishTestResults@2 displayName: "Publish junit-style results" diff --git a/appveyor.yml b/appveyor.yml index ecb893a3..c03f375c 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -49,6 +49,7 @@ install: - go version - go env - go get -u github.com/golang-sql/civil + - go get -u github.com/golang-sql/sqlexp build_script: - go build diff --git a/bulkimport_example_test.go b/bulkimport_example_test.go index 54d3dc2d..4fafbc99 100644 --- a/bulkimport_example_test.go +++ b/bulkimport_example_test.go @@ -1,11 +1,10 @@ +//go:build go1.10 // +build go1.10 package mssql_test import ( "database/sql" - "flag" - "fmt" "log" "strings" "unicode/utf8" @@ -32,19 +31,8 @@ const ( // This example shows how to perform bulk imports func ExampleCopyIn() { - flag.Parse() - - if *debug { - fmt.Printf(" password:%s\n", *password) - fmt.Printf(" port:%d\n", *port) - fmt.Printf(" server:%s\n", *server) - fmt.Printf(" user:%s\n", *user) - } connString := makeConnURL().String() - if *debug { - fmt.Printf(" connString:%s\n", connString) - } db, err := sql.Open("sqlserver", connString) if err != nil { diff --git a/datetimeoffset_example_test.go b/datetimeoffset_example_test.go index fa3dffb3..ad419c41 100644 --- a/datetimeoffset_example_test.go +++ b/datetimeoffset_example_test.go @@ -1,10 +1,10 @@ +//go:build go1.10 // +build go1.10 package mssql_test import ( "database/sql" - "flag" "fmt" "log" "time" @@ -15,19 +15,8 @@ import ( // This example shows how to insert and retrieve date and time types data func ExampleDateTimeOffset() { - flag.Parse() - - if *debug { - fmt.Printf(" password:%s\n", *password) - fmt.Printf(" port:%d\n", *port) - fmt.Printf(" server:%s\n", *server) - fmt.Printf(" user:%s\n", *user) - } connString := makeConnURL().String() - if *debug { - fmt.Printf(" connString:%s\n", connString) - } db, err := sql.Open("sqlserver", connString) if err != nil { diff --git a/go.mod b/go.mod index 47efb02c..02f7666c 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0 github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0 github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe + github.com/golang-sql/sqlexp v0.0.0-20170517235910-f1bb20e5a188 github.com/jcmturner/gokrb5/v8 v8.4.2 golang.org/x/crypto v0.0.0-20201112155050-0c6587e931a9 ) diff --git a/go.sum b/go.sum index 44d47e92..b0beeca8 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/jcmturner/gokrb5/v8 v8.4.2 h1:6ZIM6b/JJN0X8UM43ZOM6Z4SJzla+a/u7scXFJz github.com/jcmturner/gokrb5/v8 v8.4.2/go.mod h1:sb+Xq/fTY5yktf/VxLsE3wlfPqQjp0aWNYyvBVK62bc= github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= +github.com/golang-sql/sqlexp v0.0.0-20170517235910-f1bb20e5a188 h1:+eHOFJl1BaXrQxKX+T06f78590z4qA2ZzBTqahsKSE4= +github.com/golang-sql/sqlexp v0.0.0-20170517235910-f1bb20e5a188/go.mod h1:vXjM/+wXQnTPR4KqTKDgJukSZ6amVRtWMPEjE6sQoK8= github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 h1:49lOXmGaUpV9Fz3gd7TFZY106KVlPVa5jcYD1gaQf98= github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA= diff --git a/lastinsertid_example_test.go b/lastinsertid_example_test.go index 9a82284f..260b44ec 100644 --- a/lastinsertid_example_test.go +++ b/lastinsertid_example_test.go @@ -1,29 +1,17 @@ +//go:build go1.10 // +build go1.10 package mssql_test import ( "database/sql" - "flag" - "fmt" "log" ) // This example shows the usage of Connector type func ExampleLastInsertId() { - flag.Parse() - - if *debug { - fmt.Printf(" password:%s\n", *password) - fmt.Printf(" port:%d\n", *port) - fmt.Printf(" server:%s\n", *server) - fmt.Printf(" user:%s\n", *user) - } connString := makeConnURL().String() - if *debug { - fmt.Printf(" connString:%s\n", connString) - } db, err := sql.Open("sqlserver", connString) if err != nil { diff --git a/messages_benchmark_test.go b/messages_benchmark_test.go new file mode 100644 index 00000000..b7425429 --- /dev/null +++ b/messages_benchmark_test.go @@ -0,0 +1,63 @@ +// +build go1.14 + +package mssql + +import ( + "testing" +) + +func BenchmarkMessageQueue(b *testing.B) { + conn, logger := open(b) + defer conn.Close() + defer logger.StopLogging() + + b.Run("BlockingQuery", func(b *testing.B) { + var errs, results float64 + for i := 0; i < b.N; i++ { + r, err := conn.Query(mixedQuery) + if err != nil { + b.Fatal(err.Error()) + } + defer r.Close() + active := true + first := true + for active { + active = r.Next() + if active && first { + results++ + } + first = false + if !active { + if r.Err() != nil { + b.Logf("r.Err:%v", r.Err()) + errs++ + } + active = r.NextResultSet() + if active { + first = true + } + } + } + } + b.ReportMetric(float64(0), "msgs/op") + b.ReportMetric(errs/float64(b.N), "errors/op") + b.ReportMetric(results/float64(b.N), "results/op") + }) + b.Run("NonblockingQuery", func(b *testing.B) { + var msgs, errs, results, rowcounts float64 + for i := 0; i < b.N; i++ { + m, e, r, rc := testMixedQuery(conn, b) + msgs += float64(m) + errs += float64(e) + results += float64(r) + rowcounts += float64(rc) + if r != 4 { + b.Fatalf("Got wrong results count: %d, expected 4", r) + } + } + b.ReportMetric(msgs/float64(b.N), "msgs/op") + b.ReportMetric(errs/float64(b.N), "errors/op") + b.ReportMetric(results/float64(b.N), "results/op") + b.ReportMetric(rowcounts/float64(b.N), "rowcounts/op") + }) +} diff --git a/messages_example_test.go b/messages_example_test.go new file mode 100644 index 00000000..37dc0e8f --- /dev/null +++ b/messages_example_test.go @@ -0,0 +1,76 @@ +//go:build go1.10 +// +build go1.10 + +package mssql_test + +import ( + "context" + "database/sql" + "fmt" + "log" + + mssql "github.com/denisenkom/go-mssqldb" + "github.com/golang-sql/sqlexp" +) + +const ( + msgQuery = `select 'name' as Name +PRINT N'This is a message' +select 199 +RAISERROR (N'Testing!' , 11, 1) +select 300 +` +) + +// This example shows the usage of sqlexp/Messages +func ExampleRows_usingmessages() { + + connString := makeConnURL().String() + + // Create a new connector object by calling NewConnector + connector, err := mssql.NewConnector(connString) + if err != nil { + log.Println(err) + return + } + + // Pass connector to sql.OpenDB to get a sql.DB object + db := sql.OpenDB(connector) + defer db.Close() + retmsg := &sqlexp.ReturnMessage{} + ctx := context.Background() + rows, err := db.QueryContext(ctx, msgQuery, retmsg) + if err != nil { + log.Fatalf("QueryContext failed: %v", err) + } + active := true + for active { + msg := retmsg.Message(ctx) + switch m := msg.(type) { + case sqlexp.MsgNotice: + fmt.Println(m.Message) + case sqlexp.MsgNext: + inresult := true + for inresult { + inresult = rows.Next() + if inresult { + cols, err := rows.Columns() + if err != nil { + log.Fatalf("Columns failed: %v", err) + } + fmt.Println(cols) + var d interface{} + if err = rows.Scan(&d); err == nil { + fmt.Println(d) + } + } + } + case sqlexp.MsgNextResultSet: + active = rows.NextResultSet() + case sqlexp.MsgError: + fmt.Println("Error:", m.Error) + case sqlexp.MsgRowsAffected: + fmt.Println("Rows affected:", m.Count) + } + } +} diff --git a/mssql.go b/mssql.go index c34f2a03..fbd44d1f 100644 --- a/mssql.go +++ b/mssql.go @@ -17,6 +17,7 @@ import ( "github.com/denisenkom/go-mssqldb/internal/querytext" "github.com/denisenkom/go-mssqldb/msdsn" + "github.com/golang-sql/sqlexp" ) // ReturnStatus may be used to return the return value from a proc. @@ -206,6 +207,7 @@ type Conn struct { type outputs struct { params map[string]interface{} returnStatus *ReturnStatus + msgq *sqlexp.ReturnMessage } // IsValid satisfies the driver.Validator interface. @@ -667,6 +669,11 @@ func (s *Stmt) processQueryResponse(ctx context.Context) (res driver.Rows, err e ctx, cancel := context.WithCancel(ctx) reader := startReading(s.c.sess, ctx, s.c.outs) s.c.clearOuts() + // For apps using a message queue, return right away and let Rowsq do all the work + if reader.outs.msgq != nil { + res = &Rowsq{stmt: s, reader: reader, cols: nil, cancel: cancel} + return res, nil + } // process metadata var cols []columnStruct loop: @@ -738,13 +745,13 @@ func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) { return &Result{s.c, reader.rowCount}, nil } +// Rows represents the non-experimental data/sql model for Query and QueryContext type Rows struct { stmt *Stmt cols []columnStruct reader *tokenProcessor nextCols []columnStruct - - cancel func() + cancel func() } func (rc *Rows) Close() error { @@ -772,6 +779,7 @@ func (rc *Rows) Close() error { } func (rc *Rows) Columns() (res []string) { + res = make([]string, len(rc.cols)) for i, col := range rc.cols { res[i] = col.ColName @@ -793,6 +801,7 @@ func (rc *Rows) Next(dest []driver.Value) error { return io.EOF } else { switch tokdata := tok.(type) { + // processQueryResponse may have delegated all the token reading to us case []columnStruct: rc.nextCols = tokdata return io.EOF @@ -1058,3 +1067,214 @@ func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive } return s.exec(ctx, list) } + +// Rowsq implements the sqlexp messages model for Query and QueryContext +// Theory: We could also implement the non-experimental model this way +type Rowsq struct { + stmt *Stmt + cols []columnStruct + reader *tokenProcessor + nextCols []columnStruct + cancel func() + requestDone bool + inResultSet bool +} + +func (rc *Rowsq) Close() error { + rc.cancel() + + for { + tok, err := rc.reader.nextToken() + if err == nil { + if tok == nil { + return nil + } else { + // continue consuming tokens + continue + } + } else { + if err == rc.reader.ctx.Err() { + return nil + } else { + return err + } + } + } +} + +// data/sql calls Columns during the app's call to Next +func (rc *Rowsq) Columns() (res []string) { + if rc.cols == nil { + scan: + for { + tok, err := rc.reader.nextToken() + if err == nil { + if rc.reader.sess.logFlags&logDebug != 0 { + rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("Columns() token type:%v", reflect.TypeOf(tok))) + } + if tok == nil { + return []string{} + } else { + switch tokdata := tok.(type) { + case []columnStruct: + rc.cols = tokdata + rc.inResultSet = true + break scan + } + } + } + } + } + res = make([]string, len(rc.cols)) + for i, col := range rc.cols { + res[i] = col.ColName + } + return +} + +func (rc *Rowsq) Next(dest []driver.Value) error { + if !rc.stmt.c.connectionGood { + return driver.ErrBadConn + } + for { + tok, err := rc.reader.nextToken() + if rc.reader.sess.logFlags&logDebug != 0 { + rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("Next() token type:%v", reflect.TypeOf(tok))) + } + if err == nil { + if tok == nil { + return io.EOF + } else { + switch tokdata := tok.(type) { + case []interface{}: + for i := range dest { + dest[i] = tokdata[i] + } + return nil + case doneStruct: + if tokdata.Status&doneMore == 0 { + rc.requestDone = true + } + if tokdata.isError() { + e := rc.stmt.c.checkBadConn(rc.reader.ctx, tokdata.getError(), false) + switch e.(type) { + case Error: + // Ignore non-fatal server errors. Fatal errors are of type ServerError + default: + return e + } + } + if rc.inResultSet { + rc.inResultSet = false + return io.EOF + } + case ReturnStatus: + if rc.reader.outs.returnStatus != nil { + *rc.reader.outs.returnStatus = tokdata + } + } + } + + } else { + return rc.stmt.c.checkBadConn(rc.reader.ctx, err, false) + } + } +} + +// In Message Queue mode, we always claim another resultset could be on the way +// to avoid Rows being closed prematurely +func (rc *Rowsq) HasNextResultSet() bool { + return !rc.requestDone +} + +// Scans to the next set of columns in the stream +// Note that the caller may not have read all the rows in the prior set +func (rc *Rowsq) NextResultSet() error { + if rc.requestDone { + return io.EOF + } +scan: + for { + // we should have a columns token in the channel if we aren't at the end + tok, err := rc.reader.nextToken() + if rc.reader.sess.logFlags&logDebug != 0 { + rc.reader.sess.logger.Log(rc.reader.ctx, msdsn.LogDebug, fmt.Sprintf("NextResultSet() token type:%v", reflect.TypeOf(tok))) + } + + if err != nil { + return err + } + if tok == nil { + return io.EOF + } + switch tokdata := tok.(type) { + case []columnStruct: + rc.nextCols = tokdata + rc.inResultSet = true + break scan + case doneStruct: + if tokdata.Status&doneMore == 0 { + rc.nextCols = nil + rc.requestDone = true + break scan + } + } + } + rc.cols = rc.nextCols + rc.nextCols = nil + if rc.cols == nil { + return io.EOF + } + return nil +} + +// It should return +// the value type that can be used to scan types into. For example, the database +// column type "bigint" this should return "reflect.TypeOf(int64(0))". +func (r *Rowsq) ColumnTypeScanType(index int) reflect.Type { + return makeGoLangScanType(r.cols[index].ti) +} + +// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the +// database system type name without the length. Type names should be uppercase. +// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT", +// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML", +// "TIMESTAMP". +func (r *Rowsq) ColumnTypeDatabaseTypeName(index int) string { + return makeGoLangTypeName(r.cols[index].ti) +} + +// RowsColumnTypeLength may be implemented by Rows. It should return the length +// of the column type if the column is a variable length type. If the column is +// not a variable length type ok should return false. +// If length is not limited other than system limits, it should return math.MaxInt64. +// The following are examples of returned values for various types: +// TEXT (math.MaxInt64, true) +// varchar(10) (10, true) +// nvarchar(10) (10, true) +// decimal (0, false) +// int (0, false) +// bytea(30) (30, true) +func (r *Rowsq) ColumnTypeLength(index int) (int64, bool) { + return makeGoLangTypeLength(r.cols[index].ti) +} + +// It should return +// the precision and scale for decimal types. If not applicable, ok should be false. +// The following are examples of returned values for various types: +// decimal(38, 4) (38, 4, true) +// int (0, 0, false) +// decimal (math.MaxInt64, math.MaxInt64, true) +func (r *Rowsq) ColumnTypePrecisionScale(index int) (int64, int64, bool) { + return makeGoLangTypePrecisionScale(r.cols[index].ti) +} + +// The nullable value should +// be true if it is known the column may be null, or false if the column is known +// to be not nullable. +// If the column nullability is unknown, ok should be false. +func (r *Rowsq) ColumnTypeNullable(index int) (nullable, ok bool) { + nullable = r.cols[index].Flags&colFlagNullable != 0 + ok = true + return +} diff --git a/mssql_go19.go b/mssql_go19.go index e77eebba..508b03a0 100644 --- a/mssql_go19.go +++ b/mssql_go19.go @@ -10,6 +10,8 @@ import ( "reflect" "time" + "github.com/golang-sql/sqlexp" + // "github.com/cockroachdb/apd" "github.com/golang-sql/civil" ) @@ -114,6 +116,10 @@ func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error { return driver.ErrRemoveArgument case TVP: return nil + case *sqlexp.ReturnMessage: + sqlexp.ReturnMessageInit(v) + c.outs.msgq = v + return driver.ErrRemoveArgument default: var err error nv.Value, err = convertInputParameter(nv.Value) diff --git a/newconnector_example_test.go b/newconnector_example_test.go index 613866bb..8dc74baa 100644 --- a/newconnector_example_test.go +++ b/newconnector_example_test.go @@ -1,3 +1,4 @@ +//go:build go1.10 // +build go1.10 package mssql_test @@ -18,7 +19,7 @@ var ( debug = flag.Bool("debug", false, "enable debugging") password = flag.String("password", "", "the database password") port *int = flag.Int("port", 1433, "the database port") - server = flag.String("server", "", "the database server") + server = flag.String("server", ".", "the database server") user = flag.String("user", "", "the database user") ) @@ -32,23 +33,31 @@ const ( ) func makeConnURL() *url.URL { + flag.Parse() + if *debug { + fmt.Printf(" password:%s\n", *password) + fmt.Printf(" port:%d\n", *port) + fmt.Printf(" server:%s\n", *server) + fmt.Printf(" user:%s\n", *user) + } + + params, err := mssql.GetConnParams() + if err == nil && params != nil { + return params.URL() + } + var userInfo *url.Userinfo + if *user != "" { + userInfo = url.UserPassword(*user, *password) + } return &url.URL{ Scheme: "sqlserver", Host: *server + ":" + strconv.Itoa(*port), - User: url.UserPassword(*user, *password), + User: userInfo, } } // This example shows the usage of Connector type func ExampleConnector() { - flag.Parse() - - if *debug { - fmt.Printf(" password:%s\n", *password) - fmt.Printf(" port:%d\n", *port) - fmt.Printf(" server:%s\n", *server) - fmt.Printf(" user:%s\n", *user) - } connString := makeConnURL().String() if *debug { diff --git a/queries_go19_test.go b/queries_go19_test.go index bbb75d74..12371094 100644 --- a/queries_go19_test.go +++ b/queries_go19_test.go @@ -7,9 +7,12 @@ import ( "context" "database/sql" "fmt" + "reflect" "regexp" "testing" "time" + + "github.com/golang-sql/sqlexp" ) func TestOutputParam(t *testing.T) { @@ -1105,3 +1108,263 @@ func TestClearReturnStatus(t *testing.T) { t.Errorf("expected status=42, got %d", rs) } } + +func TestMessageQueue(t *testing.T) { + conn, logger := open(t) + defer conn.Close() + defer logger.StopLogging() + retmsg := &sqlexp.ReturnMessage{} + latency, _ := getLatency(t) + ctx, cancel := context.WithTimeout(context.Background(), latency+200000*time.Millisecond) + defer cancel() + rows, err := conn.QueryContext(ctx, "PRINT 'msg1'; select 100 as c; PRINT 'msg2'", retmsg) + if err != nil { + t.Fatal(err.Error()) + } + defer rows.Close() + active := true + + msgs := []interface{}{ + sqlexp.MsgNotice{Message: "msg1"}, + sqlexp.MsgNext{}, + sqlexp.MsgRowsAffected{Count: 1}, + sqlexp.MsgNotice{Message: "msg2"}, + sqlexp.MsgNextResultSet{}, + } + i := 0 + rsCount := 0 + for active { + msg := retmsg.Message(ctx) + if i >= len(msgs) { + t.Fatalf("Got extra message:%+v", msg) + } + t.Log(reflect.TypeOf(msg)) + if reflect.TypeOf(msgs[i]) != reflect.TypeOf(msg) { + t.Fatalf("Out of order or incorrect message at %d. Actual: %+v. Expected: %+v", i, reflect.TypeOf(msg), reflect.TypeOf(msgs[i])) + } + switch m := msg.(type) { + case sqlexp.MsgNotice: + t.Log(m.Message) + case sqlexp.MsgNextResultSet: + active = rows.NextResultSet() + if active { + t.Fatal("NextResultSet returned true") + } + rsCount++ + case sqlexp.MsgNext: + if !rows.Next() { + t.Fatal("rows.Next() returned false") + } + var c int + err = rows.Scan(&c) + if err != nil { + t.Fatalf("rows.Scan() failed: %s", err.Error()) + } + if c != 100 { + t.Fatalf("query returned wrong value: %d", c) + } + } + i++ + } +} + +func TestAdvanceResultSetAfterPartialRead(t *testing.T) { + conn, logger := open(t) + defer conn.Close() + defer logger.StopLogging() + + ctx := context.Background() + retmsg := &sqlexp.ReturnMessage{} + + rows, err := conn.QueryContext(ctx, "select top 2 object_id from sys.all_objects; print 'this is a message'; select 100 as Count; ", retmsg) + if err != nil { + t.Fatal(err.Error()) + } + defer rows.Close() + + rows.Next() + var g interface{} + err = rows.Scan(&g) + if err != nil { + t.Fatalf("Scan failed: %v", err) + } + next := rows.NextResultSet() + if !next { + t.Fatalf("NextResultSet returned false") + } + next = rows.Next() + if !next { + t.Fatalf("Next on the second result set returned false") + } + cols, err := rows.Columns() + if err != nil { + t.Fatalf("Columns() error: %s", err) + } + if cols[0] != "Count" { + t.Fatalf("Wrong column in second result:%s, expected Count", cols[0]) + } + var c int + err = rows.Scan(&c) + if err != nil { + t.Fatalf("Scan errored out on second result: %s", err) + } + if c != 100 { + t.Fatalf("Scan returned incorrect value on second result set: %d, expected 100", c) + } +} +func TestMessageQueueWithErrors(t *testing.T) { + conn, logger := open(t) + defer conn.Close() + defer logger.StopLogging() + + msgs, errs, results, rowcounts := testMixedQuery(conn, t) + if msgs != 1 { + t.Fatalf("Got %d messages, expected 1", msgs) + } + if errs != 1 { + t.Fatalf("Got %d errors, expected 1", errs) + } + if results != 4 { + t.Fatalf("Got %d results, expected 4", results) + } + if rowcounts != 4 { + t.Fatalf("Got %d row counts, expected 4", rowcounts) + } +} + +const mixedQuery = `select top 5 name from sys.system_columns +select getdate() +PRINT N'This is a message' +select 199 +RAISERROR (N'Testing!' , 11, 1) +select 300 +` + +func testMixedQuery(conn *sql.DB, b testing.TB) (msgs, errs, results, rowcounts int) { + ctx := context.Background() + retmsg := &sqlexp.ReturnMessage{} + r, err := conn.QueryContext(ctx, mixedQuery, retmsg) + if err != nil { + b.Fatal(err.Error()) + } + defer r.Close() + active := true + first := true + for active { + msg := retmsg.Message(ctx) + switch m := msg.(type) { + case sqlexp.MsgNotice: + b.Logf("MsgNotice:%s", m.Message) + msgs++ + case sqlexp.MsgNext: + b.Logf("MsgNext") + inresult := true + for inresult { + inresult = r.Next() + if first { + if !inresult { + b.Fatalf("First Next call returned false") + } + results++ + } + if inresult { + var d interface{} + err = r.Scan(&d) + if err != nil { + b.Fatalf("Scan failed:%v", err) + } + b.Logf("Row data:%v", d) + } + first = false + } + case sqlexp.MsgNextResultSet: + b.Log("MsgNextResultSet") + active = r.NextResultSet() + first = true + case sqlexp.MsgError: + b.Logf("MsgError:%v", m.Error) + errs++ + case sqlexp.MsgRowsAffected: + b.Logf("MsgRowsAffected:%d", m.Count) + rowcounts++ + } + } + return msgs, errs, results, rowcounts +} + +func TestTimeoutWithNoResults(t *testing.T) { + conn, logger := open(t) + defer conn.Close() + defer logger.StopLogging() + + latency, _ := getLatency(t) + ctx, cancel := context.WithTimeout(context.Background(), latency+5000*time.Millisecond) + defer cancel() + retmsg := &sqlexp.ReturnMessage{} + r, err := conn.QueryContext(ctx, `waitfor delay '00:00:15'; select 100`, retmsg) + if err != nil { + t.Fatal(err.Error()) + } + defer r.Close() + active := true + for active { + msg := retmsg.Message(ctx) + t.Logf("Got a message: %s", reflect.TypeOf(msg)) + switch m := msg.(type) { + case sqlexp.MsgNextResultSet: + active = r.NextResultSet() + if active { + t.Fatal("NextResultSet returned true") + } + case sqlexp.MsgNext: + if r.Next() { + t.Fatal("Got a successful Next even though the query should have timed out") + } + case sqlexp.MsgRowsAffected: + t.Fatalf("Got a MsgRowsAffected %d", m.Count) + } + } + if r.Err() != context.DeadlineExceeded { + t.Fatalf("Unexpected error: %v", r.Err()) + } + +} + +func TestCancelWithNoResults(t *testing.T) { + conn, logger := open(t) + defer conn.Close() + defer logger.StopLogging() + + latency, _ := getLatency(t) + ctx, cancel := context.WithTimeout(context.Background(), latency+5000*time.Millisecond) + retmsg := &sqlexp.ReturnMessage{} + r, err := conn.QueryContext(ctx, `waitfor delay '00:00:15'; select 100`, retmsg) + if err != nil { + cancel() + t.Fatal(err.Error()) + } + defer r.Close() + time.Sleep(latency + 100*time.Millisecond) + cancel() + active := true + for active { + msg := retmsg.Message(ctx) + t.Logf("Got a message: %s", reflect.TypeOf(msg)) + switch m := msg.(type) { + case sqlexp.MsgNextResultSet: + active = r.NextResultSet() + if active { + t.Fatal("NextResultSet returned true") + } + case sqlexp.MsgNext: + if r.Next() { + t.Fatal("Got a successful Next even though the query should been cancelled") + } + case sqlexp.MsgRowsAffected: + t.Fatalf("Got a MsgRowsAffected %d", m.Count) + } + } + if r.Err() != context.Canceled { + t.Fatalf("Unexpected error: %v", r.Err()) + } +} diff --git a/queries_test.go b/queries_test.go index 438db695..2b1125df 100644 --- a/queries_test.go +++ b/queries_test.go @@ -2077,7 +2077,7 @@ func TestLoginTimeout(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), latency+(2*increment)) defer cancel() _, err := conn.ExecContext(ctx, "waitfor delay '00:00:03'") - t.Log("Got error ", err) + t.Logf("Got error type %v: %s ", reflect.TypeOf(err), err.Error()) if oe, ok := err.(*net.OpError); ok { if !oe.Timeout() { t.Fatalf("Got non-timeout error %s", oe.Error()) diff --git a/tds.go b/tds.go index deea9522..23a1d0f9 100644 --- a/tds.go +++ b/tds.go @@ -245,6 +245,13 @@ func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) { return results, nil } +// OptionFlags1 +// http://msdn.microsoft.com/en-us/library/dd304019.aspx +const ( + fUseDB = 0x20 + fSetLang = 0x80 +) + // OptionFlags2 // http://msdn.microsoft.com/en-us/library/dd304019.aspx const ( @@ -981,6 +988,7 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont PacketSize: packetSize, Database: p.Database, OptionFlags2: fODBC, // to get unlimited TEXTSIZE + OptionFlags1: fUseDB | fSetLang, HostName: p.Workstation, ServerName: serverName, AppName: p.AppName, diff --git a/tds_go110_test.go b/tds_go110_test.go index 8f2b0d50..c36f29a9 100644 --- a/tds_go110_test.go +++ b/tds_go110_test.go @@ -7,10 +7,9 @@ import ( "testing" ) -func open(t *testing.T) (*sql.DB, *testLogger) { +func open(t testing.TB) (*sql.DB, *testLogger) { tl := testLogger{t: t} SetLogger(&tl) - checkConnStr(t) connector, err := NewConnector(makeConnStr(t).String()) if err != nil { t.Error("Open connection failed:", err.Error()) diff --git a/tds_login_test.go b/tds_login_test.go index 08f3e3e5..f7c4cb79 100644 --- a/tds_login_test.go +++ b/tds_login_test.go @@ -75,8 +75,16 @@ func testLoginSequenceServer(result chan error, conn net.Conn, expectedPackets, for bi := 0; bi < n; bi++ { if expectedBytes[bi+b] != packet[bi] { - err = fmt.Errorf("Client sent unexpected byte %02X != %02X at offset %d of packet %d", - packet[bi], expectedBytes[bi+b], bi+b, i) + suffix := "" + if bi > 0 { + suffix = fmt.Sprintf("Previous byte: %02X", packet[bi-1]) + } + if bi < n { + suffix = fmt.Sprintf("%s Next byte:%02X", suffix, packet[bi+1]) + } + err = fmt.Errorf("Client sent unexpected byte %02X != %02X at offset %d of packet %d. %s", + packet[bi], expectedBytes[bi+b], bi+b, i, suffix) + result <- err return } @@ -126,7 +134,7 @@ func TestLoginWithSQLServerAuth(t *testing.T) { "01 ff 00 00 00 00 00 00 00 00 00 00 00 00 00\n", " 10 01 00 b2 00 00 01 00 aa 00 00 00 04 00 00 74\n" + "00 10 00 00 00 00 00 00 00 00 00 00 00 00 00 00\n" + - "00 02 00 00 00 00 00 00 00 00 00 00 5e 00 09 00\n" + + "A0 02 00 00 00 00 00 00 00 00 00 00 5e 00 09 00\n" + "70 00 04 00 78 00 06 00 84 00 0a 00 98 00 09 00\n" + "00 00 00 00 aa 00 00 00 aa 00 00 00 aa 00 00 00\n" + "00 00 00 00 00 00 aa 00 00 00 aa 00 00 00 aa 00\n" + @@ -187,7 +195,7 @@ func TestLoginWithSecurityTokenAuth(t *testing.T) { "00 00 00 00 01\n", " 10 01 00 BB 00 00 01 00 B3 00 00 00 04 00 00 74\n" + "00 10 00 00 00 00 00 00 00 00 00 00 00 00 00 00\n" + - "00 02 00 10 00 00 00 00 00 00 00 00 5E 00 09 00\n" + + "A0 02 00 10 00 00 00 00 00 00 00 00 5E 00 09 00\n" + "70 00 00 00 70 00 00 00 70 00 0A 00 84 00 09 00\n" + "96 00 04 00 96 00 00 00 96 00 00 00 96 00 00 00\n" + "00 00 00 00 00 00 96 00 00 00 96 00 00 00 96 00\n" + @@ -250,7 +258,7 @@ func TestLoginWithADALUsernamePasswordAuth(t *testing.T) { "00 00 00 00 01\n", " 10 01 00 aa 00 00 01 00 a2 00 00 00 04 00 00 74\n" + "00 10 00 00 00 00 00 00 00 00 00 00 00 00 00 00\n" + - "00 02 00 10 00 00 00 00 00 00 00 00 5e 00 09 00\n" + + "A0 02 00 10 00 00 00 00 00 00 00 00 5e 00 09 00\n" + "70 00 00 00 70 00 00 00 70 00 0a 00 84 00 09 00\n" + "96 00 04 00 96 00 00 00 96 00 00 00 96 00 00 00\n" + "00 00 00 00 00 00 96 00 00 00 96 00 00 00 96 00\n" + @@ -324,7 +332,7 @@ func TestLoginWithADALManagedIdentityAuth(t *testing.T) { "00 00 00 00 01\n", " 10 01 00 aa 00 00 01 00 a2 00 00 00 04 00 00 74\n" + "00 10 00 00 00 00 00 00 00 00 00 00 00 00 00 00\n" + - "00 02 00 10 00 00 00 00 00 00 00 00 5e 00 09 00\n" + + "A0 02 00 10 00 00 00 00 00 00 00 00 5e 00 09 00\n" + "70 00 00 00 70 00 00 00 70 00 0a 00 84 00 09 00\n" + "96 00 04 00 96 00 00 00 96 00 00 00 96 00 00 00\n" + "00 00 00 00 00 00 96 00 00 00 96 00 00 00 96 00\n" + diff --git a/tds_test.go b/tds_test.go index bfce86c5..7d9af0de 100644 --- a/tds_test.go +++ b/tds_test.go @@ -202,25 +202,43 @@ func TestSendSqlBatch(t *testing.T) { // returns parsed connection parameters derived from // environment variables func testConnParams(t testing.TB) msdsn.Config { + params, err := GetConnParams() + if err != nil { + t.Fatal("unable to parse SQLSERVER_DSN or read .connstr", err) + } + if params == nil { + t.Skip("no database connection string") + return msdsn.Config{} + } + return *params +} + +// TestConnParams returns a connection configuration based on environment variables or the contents of a text file +// Set environment variable SQLSERVER_DSN to provide an entire connection string +// Set environment variables HOST and DATABASE from which a minimal config will be created. +// If HOST and DATABASE are set, you can optionally set INSTANCE, SQLUSER, and SQLPASSWORD as well +// If environment variables are not set, it will look in the working directory for a file named .connstr +// If the file exists it will use the first line of the file as the file as the DSN +func GetConnParams() (*msdsn.Config, error) { dsn := os.Getenv("SQLSERVER_DSN") const logFlags = 127 if len(dsn) > 0 { params, _, err := msdsn.Parse(dsn) if err != nil { - t.Fatal("unable to parse SQLSERVER_DSN", err) + return nil, err } params.LogFlags = logFlags - return params + return ¶ms, nil } if len(os.Getenv("HOST")) > 0 && len(os.Getenv("DATABASE")) > 0 { - return msdsn.Config{ + return &msdsn.Config{ Host: os.Getenv("HOST"), Instance: os.Getenv("INSTANCE"), Database: os.Getenv("DATABASE"), User: os.Getenv("SQLUSER"), Password: os.Getenv("SQLPASSWORD"), LogFlags: logFlags, - } + }, nil } // try loading connection string from file f, err := os.Open(".connstr") @@ -228,17 +246,17 @@ func testConnParams(t testing.TB) msdsn.Config { rdr := bufio.NewReader(f) dsn, err := rdr.ReadString('\n') if err != io.EOF && err != nil { - t.Fatal(err) + return nil, err } params, _, err := msdsn.Parse(dsn) if err != nil { - t.Fatal("unable to parse connection string loaded from file", err) + return nil, err } params.LogFlags = logFlags - return params + return ¶ms, nil } - t.Skip("no database connection string") - return msdsn.Config{} + + return nil, nil } func checkConnStr(t testing.TB) { diff --git a/token.go b/token.go index 643a78ac..43039d3d 100644 --- a/token.go +++ b/token.go @@ -10,6 +10,7 @@ import ( "strconv" "github.com/denisenkom/go-mssqldb/msdsn" + "github.com/golang-sql/sqlexp" ) //go:generate go run golang.org/x/tools/cmd/stringer -type token @@ -108,6 +109,7 @@ func (d doneStruct) getError() Error { return Error{Message: "Request failed but didn't provide reason"} } err := d.errors[n-1] + // should this return the most severe error? err.All = make([]Error, n) copy(err.All, d.errors) return err @@ -643,6 +645,7 @@ func parseReturnValue(r *tdsBuffer) (nv namedValue) { } func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs outputs) { + firstResult := true defer func() { if err := recover(); err != nil { if sess.logFlags&logErrors != 0 { @@ -692,30 +695,67 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS ch <- order case tokenDoneInProc: done := parseDoneInProc(sess.buf) + + ch <- done if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 { - sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(%d row(s) affected)", done.RowCount)) + sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(%d rows affected)", done.RowCount)) + + if outs.msgq != nil { + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgRowsAffected{Count: int64(done.RowCount)}) + } + } + if done.Status&doneMore == 0 { + if outs.msgq != nil { + // For now we ignore ctx->Done errors that ReturnMessageEnqueue might return + // It's not clear how to handle them correctly here, and data/sql seems + // to set Rows.Err correctly when ctx expires already + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) + } + return } - ch <- done case tokenDone, tokenDoneProc: done := parseDone(sess.buf) done.errors = errs + if outs.msgq != nil { + errs = make([]Error, 0, 5) + } if sess.logFlags&logDebug != 0 { sess.logger.Log(ctx, msdsn.LogDebug, fmt.Sprintf("got DONE or DONEPROC status=%d", done.Status)) } if done.Status&doneSrvError != 0 { ch <- ServerError{done.getError()} + if outs.msgq != nil { + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) + } return } if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 { sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(%d row(s) affected)", done.RowCount)) } ch <- done + if done.Status&doneCount != 0 { + if outs.msgq != nil { + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgRowsAffected{Count: int64(done.RowCount)}) + } + } if done.Status&doneMore == 0 { + if outs.msgq != nil { + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) + } return } case tokenColMetadata: columns = parseColMetadata72(sess.buf) ch <- columns + + if outs.msgq != nil { + if !firstResult { + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) + } + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNext{}) + } + firstResult = false + case tokenRow: row := make([]interface{}, len(columns)) parseRow(sess.buf, columns, row) @@ -735,6 +775,9 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS if sess.logFlags&logErrors != 0 { sess.logger.Log(ctx, msdsn.LogErrors, err.Message) } + if outs.msgq != nil { + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgError{Error: err}) + } case tokenInfo: info := parseInfo(sess.buf) if sess.logFlags&logDebug != 0 { @@ -743,6 +786,9 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS if sess.logFlags&logMessages != 0 { sess.logger.Log(ctx, msdsn.LogMessages, info.Message) } + if outs.msgq != nil { + _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNotice{Message: info.Message}) + } case tokenReturnValue: nv := parseReturnValue(sess.buf) if len(nv.Name) > 0 { @@ -854,6 +900,10 @@ func (t tokenProcessor) nextToken() (tokenStruct, error) { return nil, nil } case <-t.ctx.Done(): + // It seems the Message function on t.outs.msgq doesn't get the Done if it comes here instead + if t.outs.msgq != nil { + _ = sqlexp.ReturnMessageEnqueue(t.ctx, t.outs.msgq, sqlexp.MsgNextResultSet{}) + } if t.noAttn { return nil, t.ctx.Err() } diff --git a/tvp_example_test.go b/tvp_example_test.go index 99582155..c27f98e5 100644 --- a/tvp_example_test.go +++ b/tvp_example_test.go @@ -1,10 +1,10 @@ +//go:build go1.10 // +build go1.10 package mssql_test import ( "database/sql" - "flag" "fmt" "log" @@ -46,19 +46,7 @@ func ExampleTVP() { Currency string `json:"-"` } - flag.Parse() - - if *debug { - fmt.Printf(" password:%s\n", *password) - fmt.Printf(" port:%d\n", *port) - fmt.Printf(" server:%s\n", *server) - fmt.Printf(" user:%s\n", *user) - } - connString := makeConnURL().String() - if *debug { - fmt.Printf(" connString:%s\n", connString) - } db, err := sql.Open("sqlserver", connString) if err != nil { diff --git a/tvp_go19.go b/tvp_go19.go index d3890af9..32485a1b 100644 --- a/tvp_go19.go +++ b/tvp_go19.go @@ -105,7 +105,8 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd elemKind := field.Kind() if elemKind == reflect.Ptr && valOf.IsNil() { switch tvpVal.(type) { - case *bool, *time.Time, *int8, *int16, *int32, *int64, *float32, *float64, *int: + case *bool, *time.Time, *int8, *int16, *int32, *int64, *float32, *float64, *int, + *uint8, *uint16, *uint32, *uint64, *uint: binary.Write(buf, binary.LittleEndian, uint8(0)) continue default: diff --git a/tvp_go19_db_test.go b/tvp_go19_db_test.go index 510a4a0d..a8626988 100644 --- a/tvp_go19_db_test.go +++ b/tvp_go19_db_test.go @@ -1,3 +1,4 @@ +//go:build go1.9 // +build go1.9 package mssql @@ -1161,3 +1162,189 @@ func TestTVPObject(t *testing.T) { }) } } + +// fix pointer uint in tvp https://github.com/denisenkom/go-mssqldb/issues/703 +func TestTVPUnsigned(t *testing.T) { + checkConnStr(t) + tl := testLogger{t: t} + defer tl.StopLogging() + SetLogger(&tl) + + c := makeConnStr(t).String() + db, err := sql.Open("sqlserver", c) + if err != nil { + t.Fatalf("failed to open driver sqlserver") + } + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sqltextcreatetable := ` + CREATE TYPE unsignedTvpTableTypes AS TABLE + ( + p_tinyint TINYINT, + p_tinyintNull TINYINT, + p_smallint SMALLINT, + p_smallintNull SMALLINT, + p_int INT, + p_intNull INT, + p_bigint BIGINT, + p_bigintNull BIGINT, + pInt INT, + pIntNull INT + ); ` + + sqltextdroptable := `DROP TYPE unsignedTvpTableTypes;` + + sqltextcreatesp := ` + CREATE PROCEDURE spwithtvpUnsigned + @param1 unsignedTvpTableTypes READONLY, + @param2 unsignedTvpTableTypes READONLY, + @param3 NVARCHAR(10) + AS + BEGIN + SET NOCOUNT ON; + SELECT * FROM @param1; + SELECT * FROM @param2; + SELECT @param3; + END;` + + type TvptableRow struct { + PTinyint uint8 `db:"p_tinyint"` + PTinyintNull *uint8 `db:"p_tinyintNull"` + PSmallint uint16 `db:"p_smallint"` + PSmallintNull *uint16 `db:"p_smallintNull"` + PInt uint32 `db:"p_int"` + PIntNull *uint32 `db:"p_intNull"` + PBigint uint64 `db:"p_bigint"` + PBigintNull *uint64 `db:"p_bigintNull"` + Pint uint `db:"pInt"` + PintNull *uint `db:"pIntNull"` + } + + sqltextdropsp := `DROP PROCEDURE spwithtvpUnsigned;` + + _, err = db.ExecContext(ctx, sqltextcreatetable) + if err != nil { + t.Fatal(err) + } + defer db.ExecContext(ctx, sqltextdroptable) + + _, err = db.ExecContext(ctx, sqltextcreatesp) + if err != nil { + t.Fatal(err) + } + defer db.ExecContext(ctx, sqltextdropsp) + i8 := uint8(1) + i16 := uint16(2) + i32 := uint32(3) + i64 := uint64(4) + i := uint(5) + param1 := []TvptableRow{ + { + PTinyint: i8, + PSmallint: i16, + PInt: i32, + PBigint: i64, + Pint: 355, + }, + { + PTinyint: 5, + PSmallint: 16000, + PInt: 20000000, + PBigint: 2000000020000000, + Pint: 455, + }, + { + PTinyintNull: &i8, + PSmallintNull: &i16, + PIntNull: &i32, + PBigintNull: &i64, + PintNull: &i, + }, + { + PTinyint: 5, + PSmallint: 16000, + PInt: 20000000, + PBigint: 2000000020000000, + PTinyintNull: &i8, + PSmallintNull: &i16, + PIntNull: &i32, + PBigintNull: &i64, + PintNull: &i, + }, + } + + tvpType := TVP{ + TypeName: "unsignedTvpTableTypes", + Value: param1, + } + tvpTypeEmpty := TVP{ + TypeName: "unsignedTvpTableTypes", + Value: []TvptableRow{}, + } + + rows, err := db.QueryContext(ctx, + "exec spwithtvpUnsigned @param1, @param2, @param3", + sql.Named("param1", tvpType), + sql.Named("param2", tvpTypeEmpty), + sql.Named("param3", "test"), + ) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + var result1 []TvptableRow + for rows.Next() { + var val TvptableRow + err := rows.Scan( + &val.PTinyint, + &val.PTinyintNull, + &val.PSmallint, + &val.PSmallintNull, + &val.PInt, + &val.PIntNull, + &val.PBigint, + &val.PBigintNull, + &val.Pint, + &val.PintNull, + ) + if err != nil { + t.Fatalf("scan failed with error: %s", err) + } + + result1 = append(result1, val) + } + + if !reflect.DeepEqual(param1, result1) { + t.Logf("expected: %+v", param1) + t.Logf("actual: %+v", result1) + t.Errorf("first resultset did not match param1") + } + + if !rows.NextResultSet() { + t.Errorf("second resultset did not exist") + } + + if rows.Next() { + t.Errorf("second resultset was not empty") + } + + if !rows.NextResultSet() { + t.Errorf("third resultset did not exist") + } + + if !rows.Next() { + t.Errorf("third resultset was empty") + } + + var result3 string + if err := rows.Scan(&result3); err != nil { + t.Errorf("error scanning third result set: %s", err) + } + if result3 != "test" { + t.Errorf("third result set had wrong value expected: %s actual: %s", "test", result3) + } +} From 8059af5a58d6b3427e46ab2bb3d07a32d73bbeb8 Mon Sep 17 00:00:00 2001 From: chandanjainn Date: Mon, 24 Jan 2022 13:58:39 +0530 Subject: [PATCH 17/21] renamed kerberos config variable --- msdsn/conn_str.go | 4 ++-- msdsn/conn_str_test.go | 11 ++--------- tds.go | 4 ++-- 3 files changed, 6 insertions(+), 13 deletions(-) diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index f9272a14..ba603616 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -45,7 +45,7 @@ const ( type Kerberos struct { // Kerberos configuration details - Krb5Conf *config.Config + Krb5Config *config.Config // Credential cache Cache *credentials.CCache @@ -208,7 +208,7 @@ func Parse(dsn string) (Config, map[string]string, error) { if ok { p.Kerberos = &Kerberos{} var err error - p.Kerberos.Krb5Conf, err = setupKerbConfig(krb5ConfFile) + p.Kerberos.Krb5Config, err = setupKerbConfig(krb5ConfFile) if err != nil { return p, params, fmt.Errorf("cannot read kerberos configuration file: %w", err) } diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index 5c58e876..8bc1f792 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -1,7 +1,6 @@ package msdsn import ( - "fmt" "io/ioutil" "os" "reflect" @@ -229,13 +228,11 @@ func TestValidConnectionStringKerberos(t *testing.T) { t.Errorf("Connection string %s should fail to parse with error %s", connStrings, err) } } - deleteFile(kerberosTestFile, t) + os.Remove(kerberosTestFile) } func createKrbFile(t *testing.T) string { - dir := os.TempDir() - fmt.Println(dir) - file, err := ioutil.TempFile(dir, "test-*.txt") + file, err := ioutil.TempFile("", "test-*.txt") if err != nil { t.Errorf("Failed to create a temp file") } @@ -244,7 +241,3 @@ func createKrbFile(t *testing.T) string { } return file.Name() } - -func deleteFile(filename string, t *testing.T) { - os.Remove(filename) -} diff --git a/tds.go b/tds.go index 23a1d0f9..6738dcec 100644 --- a/tds.go +++ b/tds.go @@ -1174,8 +1174,8 @@ initiate_connection: } var auth auth var authOk bool - if p.Kerberos.Krb5Conf != nil { - auth, authOk = getKRB5Auth(p.User, p.Password, p.ServerSPN, p.Kerberos.Krb5Conf, p.Kerberos.Keytab, p.Kerberos.Cache) + if p.Kerberos != nil && p.Kerberos.Krb5Config != nil { + auth, authOk = getKRB5Auth(p.User, p.Password, p.ServerSPN, p.Kerberos.Krb5Config, p.Kerberos.Keytab, p.Kerberos.Cache) } else { auth, authOk = getAuth(p.User, p.Password, p.ServerSPN, p.Workstation) } From b4e96e3ebbf42d1a0f5fb1bf3a1c6e5002fc71a6 Mon Sep 17 00:00:00 2001 From: chandanjainn Date: Tue, 25 Jan 2022 20:20:59 +0530 Subject: [PATCH 18/21] nil pointer fix --- msdsn/conn_str.go | 6 +++--- msdsn/conn_str_test.go | 7 +++---- tds.go | 4 ++-- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index ba603616..07e36634 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -45,7 +45,7 @@ const ( type Kerberos struct { // Kerberos configuration details - Krb5Config *config.Config + Config *config.Config // Credential cache Cache *credentials.CCache @@ -132,6 +132,7 @@ var skipSetup = errors.New("skip setting up TLS") func Parse(dsn string) (Config, map[string]string, error) { p := Config{} + p.Kerberos = &Kerberos{} var params map[string]string if strings.HasPrefix(dsn, "odbc:") { @@ -206,9 +207,8 @@ func Parse(dsn string) (Config, map[string]string, error) { krb5ConfFile, ok := params["krb5conffile"] if ok { - p.Kerberos = &Kerberos{} var err error - p.Kerberos.Krb5Config, err = setupKerbConfig(krb5ConfFile) + p.Kerberos.Config, err = setupKerbConfig(krb5ConfFile) if err != nil { return p, params, fmt.Errorf("cannot read kerberos configuration file: %w", err) } diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index 8bc1f792..c20af694 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -210,13 +210,13 @@ func TestInvalidConnectionStringKerberos(t *testing.T) { _, _, err := Parse(connStr) if err == nil { t.Errorf("Connection expected to fail for connection string %s but it didn't", connStr) - continue } } } func TestValidConnectionStringKerberos(t *testing.T) { kerberosTestFile := createKrbFile(t) + defer os.Remove(kerberosTestFile) connStrings := []string{ "server=server;user id=user;port=1345;realm=domain;trustservercertificate=true;krb5conffile=" + kerberosTestFile + ";keytabfile=" + kerberosTestFile, "server=server;port=1345;realm=domain;trustservercertificate=true;krb5conffile=" + kerberosTestFile + ";krbcache=" + kerberosTestFile, @@ -228,16 +228,15 @@ func TestValidConnectionStringKerberos(t *testing.T) { t.Errorf("Connection string %s should fail to parse with error %s", connStrings, err) } } - os.Remove(kerberosTestFile) } func createKrbFile(t *testing.T) string { file, err := ioutil.TempFile("", "test-*.txt") if err != nil { - t.Errorf("Failed to create a temp file") + t.Fatalf("Failed to create a temp file:%v",err) } if _, err := file.Write([]byte("This is a test file\n")); err != nil { - t.Errorf("Failed to write file") + t.Fatalf("Failed to write file:%v",err) } return file.Name() } diff --git a/tds.go b/tds.go index 6738dcec..f53742ec 100644 --- a/tds.go +++ b/tds.go @@ -1174,8 +1174,8 @@ initiate_connection: } var auth auth var authOk bool - if p.Kerberos != nil && p.Kerberos.Krb5Config != nil { - auth, authOk = getKRB5Auth(p.User, p.Password, p.ServerSPN, p.Kerberos.Krb5Config, p.Kerberos.Keytab, p.Kerberos.Cache) + if p.Kerberos != nil && p.Kerberos.Config != nil { + auth, authOk = getKRB5Auth(p.User, p.Password, p.ServerSPN, p.Kerberos.Config, p.Kerberos.Keytab, p.Kerberos.Cache) } else { auth, authOk = getAuth(p.User, p.Password, p.ServerSPN, p.Workstation) } From dc1a81605d2836afde5543f644859985f9cef59a Mon Sep 17 00:00:00 2001 From: chandanjainn Date: Tue, 8 Feb 2022 11:15:53 +0530 Subject: [PATCH 19/21] removed commented code --- kerbauth_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/kerbauth_test.go b/kerbauth_test.go index 0a57af0a..4b7360ea 100644 --- a/kerbauth_test.go +++ b/kerbauth_test.go @@ -104,7 +104,6 @@ func TestInitialBytes(t *testing.T) { t.Errorf("Failed to get Initial bytes") } - //krbObj.initkrbwithkeytab = true _, err = krbObj.InitialBytes() if err == nil { t.Errorf("Failed to get Initial bytes") From 73ae20f6fdc5ba69ff58ab7361940f093c25a08d Mon Sep 17 00:00:00 2001 From: chandanjainn Date: Tue, 8 Feb 2022 13:49:59 +0530 Subject: [PATCH 20/21] removed unused field --- kerbauth.go | 4 +--- kerbauth_test.go | 20 +++++++------------- tds.go | 2 +- 3 files changed, 9 insertions(+), 17 deletions(-) diff --git a/kerbauth.go b/kerbauth.go index 475c944b..f0276ebf 100644 --- a/kerbauth.go +++ b/kerbauth.go @@ -16,7 +16,6 @@ type krb5Auth struct { username string realm string serverSPN string - password string port uint64 krb5Config *config.Config krbKeytab *keytab.Keytab @@ -25,7 +24,7 @@ type krb5Auth struct { state krb5ClientState } -func getKRB5Auth(user, password, serverSPN string, krb5Conf *config.Config, keytabContent *keytab.Keytab, cacheContent *credentials.CCache) (auth, bool) { +func getKRB5Auth(user, serverSPN string, krb5Conf *config.Config, keytabContent *keytab.Keytab, cacheContent *credentials.CCache) (auth, bool) { var port uint64 var realm, serviceStr string var err error @@ -73,7 +72,6 @@ func getKRB5Auth(user, password, serverSPN string, krb5Conf *config.Config, keyt krb5Config: krb5Conf, krbKeytab: keytabContent, krbCache: cacheContent, - password: password, }, true } diff --git a/kerbauth_test.go b/kerbauth_test.go index 4b7360ea..7d7f59e1 100644 --- a/kerbauth_test.go +++ b/kerbauth_test.go @@ -15,11 +15,10 @@ func TestGetKRB5Auth(t *testing.T) { krbKeytab := &keytab.Keytab{} krbCache := &credentials.CCache{} - got, _ := getKRB5Auth("", "", "MSSQLSvc/mssql.domain.com:1433", krbConf, krbKeytab, krbCache) + got, _ := getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433", krbConf, krbKeytab, krbCache) keytab := &krb5Auth{username: "", realm: "domain.com", serverSPN: "MSSQLSvc/mssql.domain.com:1433", - password: "", port: 1433, krb5Config: krbConf, krbKeytab: krbKeytab, @@ -31,11 +30,10 @@ func TestGetKRB5Auth(t *testing.T) { t.Errorf("Failed to get correct krb5Auth object\nExpected:%v\nRecieved:%v", keytab, got) } - got, _ = getKRB5Auth("", "", "MSSQLSvc/mssql.domain.com:1433", krbConf, krbKeytab, krbCache) + got, _ = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433", krbConf, krbKeytab, krbCache) keytab = &krb5Auth{username: "", realm: "domain.com", serverSPN: "MSSQLSvc/mssql.domain.com:1433", - password: "", port: 1433, krb5Config: krbConf, krbKeytab: krbKeytab, @@ -47,16 +45,15 @@ func TestGetKRB5Auth(t *testing.T) { t.Errorf("Failed to get correct krb5Auth object\nExpected:%v\nRecieved:%v", keytab, got) } - _, val := getKRB5Auth("", "", "MSSQLSvc/mssql.domain.com", krbConf, krbKeytab, krbCache) + _, val := getKRB5Auth("", "MSSQLSvc/mssql.domain.com", krbConf, krbKeytab, krbCache) if val { t.Errorf("Failed to get correct krb5Auth object: no port defined") } - got, _ = getKRB5Auth("", "", "MSSQLSvc/mssql.domain.com:1433@DOMAIN.COM", krbConf, krbKeytab, krbCache) + got, _ = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433@DOMAIN.COM", krbConf, krbKeytab, krbCache) keytab = &krb5Auth{username: "", realm: "DOMAIN.COM", serverSPN: "MSSQLSvc/mssql.domain.com:1433", - password: "", port: 1433, krb5Config: krbConf, krbKeytab: krbKeytab, @@ -68,17 +65,17 @@ func TestGetKRB5Auth(t *testing.T) { t.Errorf("Failed to get correct krb5Auth object\nExpected:%v\nRecieved:%v", keytab, got) } - _, val = getKRB5Auth("", "", "MSSQLSvc/mssql.domain.com:1433@domain.com@test", krbConf, krbKeytab, krbCache) + _, val = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:1433@domain.com@test", krbConf, krbKeytab, krbCache) if val { t.Errorf("Failed to get correct krb5Auth object due to incorrect serverSPN name") } - _, val = getKRB5Auth("", "", "MSSQLSvc/mssql.domain.com:port@domain.com", krbConf, krbKeytab, krbCache) + _, val = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:port@domain.com", krbConf, krbKeytab, krbCache) if val { t.Errorf("Failed to get correct krb5Auth object due to incorrect port") } - _, val = getKRB5Auth("", "", "MSSQLSvc/mssql.domain.com:port", krbConf, krbKeytab, krbCache) + _, val = getKRB5Auth("", "MSSQLSvc/mssql.domain.com:port", krbConf, krbKeytab, krbCache) if val { t.Errorf("Failed to get correct krb5Auth object due to incorrect port") } @@ -91,7 +88,6 @@ func TestInitialBytes(t *testing.T) { krbObj := &krb5Auth{username: "", realm: "domain.com", serverSPN: "MSSQLSvc/mssql.domain.com:1433", - password: "", port: 1433, krb5Config: krbConf, krbKeytab: krbKeytab, @@ -119,7 +115,6 @@ func TestNextBytes(t *testing.T) { var krbObj auth = &krb5Auth{username: "", realm: "domain.com", serverSPN: "MSSQLSvc/mssql.domain.com:1433", - password: "", port: 1433, krb5Config: krbConf, krbKeytab: krbKeytab, @@ -143,7 +138,6 @@ func TestFree(t *testing.T) { var krbObj auth = &krb5Auth{username: "", realm: "domain.com", serverSPN: "MSSQLSvc/mssql.domain.com:1433", - password: "", port: 1433, krb5Config: krbConf, krbKeytab: krbKeytab, diff --git a/tds.go b/tds.go index f53742ec..e28485ef 100644 --- a/tds.go +++ b/tds.go @@ -1175,7 +1175,7 @@ initiate_connection: var auth auth var authOk bool if p.Kerberos != nil && p.Kerberos.Config != nil { - auth, authOk = getKRB5Auth(p.User, p.Password, p.ServerSPN, p.Kerberos.Config, p.Kerberos.Keytab, p.Kerberos.Cache) + auth, authOk = getKRB5Auth(p.User, p.ServerSPN, p.Kerberos.Config, p.Kerberos.Keytab, p.Kerberos.Cache) } else { auth, authOk = getAuth(p.User, p.Password, p.ServerSPN, p.Workstation) } From bf01def2f8f702be4a5a180411b3178ae1ce0f3e Mon Sep 17 00:00:00 2001 From: chandanjainn Date: Thu, 10 Feb 2022 11:39:21 +0530 Subject: [PATCH 21/21] code formatting --- msdsn/conn_str_test.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index c20af694..c3040d0a 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -55,7 +55,6 @@ func TestValidConnectionString(t *testing.T) { {"server=server\\instance;database=testdb;user id=tester;password=pwd", func(p Config) bool { return p.Host == "server" && p.Instance == "instance" && p.User == "tester" && p.Password == "pwd" }}, - {"server=.", func(p Config) bool { return p.Host == "localhost" }}, {"server=(local)", func(p Config) bool { return p.Host == "localhost" }}, {"ServerSPN=serverspn;Workstation ID=workstid", func(p Config) bool { return p.ServerSPN == "serverspn" && p.Workstation == "workstid" }}, @@ -221,7 +220,6 @@ func TestValidConnectionStringKerberos(t *testing.T) { "server=server;user id=user;port=1345;realm=domain;trustservercertificate=true;krb5conffile=" + kerberosTestFile + ";keytabfile=" + kerberosTestFile, "server=server;port=1345;realm=domain;trustservercertificate=true;krb5conffile=" + kerberosTestFile + ";krbcache=" + kerberosTestFile, } - for _, connStr := range connStrings { _, _, err := Parse(connStr) if err == nil {