diff --git a/add.go b/add.go index 1b49cf12..19bce1b7 100644 --- a/add.go +++ b/add.go @@ -106,9 +106,9 @@ func (l *Conn) Add(addRequest *AddRequest) error { } if packet.Children[1].Tag == ApplicationAddResponse { - resultCode, resultDescription := getLDAPResultCode(packet) - if resultCode != 0 { - return NewError(resultCode, errors.New(resultDescription)) + err := GetLDAPError(packet) + if err != nil { + return err } } else { log.Printf("Unexpected Response: %d", packet.Children[1].Tag) diff --git a/bind.go b/bind.go index 1ad53823..59c3f5ef 100644 --- a/bind.go +++ b/bind.go @@ -79,7 +79,7 @@ func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResu } if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { + if err = addLDAPDescriptions(packet); err != nil { return nil, err } ber.PrintPacket(packet) @@ -91,20 +91,16 @@ func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResu if len(packet.Children) == 3 { for _, child := range packet.Children[2].Children { - decodedChild, err := DecodeControl(child) - if err != nil { - return nil, fmt.Errorf("failed to decode child control: %s", err) + decodedChild, decodeErr := DecodeControl(child) + if decodeErr != nil { + return nil, fmt.Errorf("failed to decode child control: %s", decodeErr) } result.Controls = append(result.Controls, decodedChild) } } - resultCode, resultDescription := getLDAPResultCode(packet) - if resultCode != 0 { - return result, NewError(resultCode, errors.New(resultDescription)) - } - - return result, nil + err = GetLDAPError(packet) + return result, err } // Bind performs a bind with the given username and password. diff --git a/compare.go b/compare.go index 5bf9bf70..5b5013cb 100644 --- a/compare.go +++ b/compare.go @@ -68,15 +68,15 @@ func (l *Conn) Compare(dn, attribute, value string) (bool, error) { } if packet.Children[1].Tag == ApplicationCompareResponse { - resultCode, resultDescription := getLDAPResultCode(packet) + err := GetLDAPError(packet) switch { - case resultCode == LDAPResultCompareTrue: + case IsErrorWithCode(err, LDAPResultCompareTrue): return true, nil - case resultCode == LDAPResultCompareFalse: + case IsErrorWithCode(err, LDAPResultCompareFalse): return false, nil default: - return false, NewError(resultCode, errors.New(resultDescription)) + return false, err } } return false, fmt.Errorf("unexpected Response: %d", packet.Children[1].Tag) diff --git a/conn.go b/conn.go index 8aece5a4..50d8c3f8 100644 --- a/conn.go +++ b/conn.go @@ -275,18 +275,18 @@ func (l *Conn) StartTLS(config *tls.Config) error { ber.PrintPacket(packet) } - if resultCode, message := getLDAPResultCode(packet); resultCode == LDAPResultSuccess { + if err := GetLDAPError(packet); err == nil { conn := tls.Client(l.conn, config) - if err := conn.Handshake(); err != nil { + if connErr := conn.Handshake(); connErr != nil { l.Close() - return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", err)) + return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", connErr)) } l.isTLS = true l.conn = conn } else { - return NewError(resultCode, fmt.Errorf("ldap: cannot StartTLS (%s)", message)) + return err } go l.reader() diff --git a/del.go b/del.go index 685a0ffd..6f78beb1 100644 --- a/del.go +++ b/del.go @@ -71,9 +71,9 @@ func (l *Conn) Del(delRequest *DelRequest) error { } if packet.Children[1].Tag == ApplicationDelResponse { - resultCode, resultDescription := getLDAPResultCode(packet) - if resultCode != 0 { - return NewError(resultCode, errors.New(resultDescription)) + err := GetLDAPError(packet) + if err != nil { + return err } } else { log.Printf("Unexpected Response: %d", packet.Children[1].Tag) diff --git a/error.go b/error.go index a2fa5515..50ed8ab3 100644 --- a/error.go +++ b/error.go @@ -176,35 +176,44 @@ var LDAPResultCodeMap = map[uint16]string{ ErrorEmptyPassword: "Empty password not allowed by the client", } -func getLDAPResultCode(packet *ber.Packet) (code uint16, description string) { - if packet == nil { - return ErrorUnexpectedResponse, "Empty packet" - } else if len(packet.Children) >= 2 { - response := packet.Children[1] - if response == nil { - return ErrorUnexpectedResponse, "Empty response in packet" - } - if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) >= 3 { - // Children[1].Children[2] is the diagnosticMessage which is guaranteed to exist as seen here: https://tools.ietf.org/html/rfc4511#section-4.1.9 - return uint16(response.Children[0].Value.(int64)), response.Children[2].Value.(string) - } - } - - return ErrorNetwork, "Invalid packet format" -} - // Error holds LDAP error information type Error struct { // Err is the underlying error Err error // ResultCode is the LDAP error code ResultCode uint16 + // MatchedDN is the matchedDN returned if any + MatchedDN string } func (e *Error) Error() string { return fmt.Sprintf("LDAP Result Code %d %q: %s", e.ResultCode, LDAPResultCodeMap[e.ResultCode], e.Err.Error()) } +// GetLDAPError creates an Error out of a BER packet representing a LDAPResult +// The return is an error object. It can be casted to a Error structure. +// This function returns nil if resultCode in the LDAPResult sequence is success(0). +func GetLDAPError(packet *ber.Packet) error { + if packet == nil { + return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty packet")} + } else if len(packet.Children) >= 2 { + response := packet.Children[1] + if response == nil { + return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty response in packet")} + } + if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) >= 3 { + resultCode := uint16(response.Children[0].Value.(int64)) + if resultCode == 0 { // No error + return nil + } + return &Error{ResultCode: resultCode, MatchedDN: response.Children[1].Value.(string), + Err: fmt.Errorf(response.Children[2].Value.(string))} + } + } + + return &Error{ResultCode: ErrorNetwork, Err: fmt.Errorf("Invalid packet format")} +} + // NewError creates an LDAP error with the given code and underlying error func NewError(resultCode uint16, err error) error { return &Error{ResultCode: resultCode, Err: err} diff --git a/error_test.go b/error_test.go index e456431b..02a3eda3 100644 --- a/error_test.go +++ b/error_test.go @@ -13,9 +13,9 @@ import ( // TestNilPacket tests that nil packets don't cause a panic. func TestNilPacket(t *testing.T) { // Test for nil packet - code, _ := getLDAPResultCode(nil) - if code != ErrorUnexpectedResponse { - t.Errorf("Should have an 'ErrorUnexpectedResponse' error in nil packets, got: %v", code) + err := GetLDAPError(nil) + if !IsErrorWithCode(err, ErrorUnexpectedResponse) { + t.Errorf("Should have an 'ErrorUnexpectedResponse' error in nil packets, got: %v", err) } // Test for nil result @@ -24,10 +24,10 @@ func TestNilPacket(t *testing.T) { nil, // Can't be nil } pack := &ber.Packet{Children: kids} - code, _ = getLDAPResultCode(pack) + err = GetLDAPError(pack) - if code != ErrorUnexpectedResponse { - t.Errorf("Should have an 'ErrorUnexpectedResponse' error in nil packets, got: %v", code) + if !IsErrorWithCode(err, ErrorUnexpectedResponse) { + t.Errorf("Should have an 'ErrorUnexpectedResponse' error in nil packets, got: %v", err) } } diff --git a/ldap.go b/ldap.go index 079f9fe0..d7666676 100644 --- a/ldap.go +++ b/ldap.go @@ -270,9 +270,9 @@ func addRequestDescriptions(packet *ber.Packet) error { } func addDefaultLDAPResponseDescriptions(packet *ber.Packet) error { - resultCode, _ := getLDAPResultCode(packet) - packet.Children[1].Children[0].Description = "Result Code (" + LDAPResultCodeMap[resultCode] + ")" - packet.Children[1].Children[1].Description = "Matched DN" + err := GetLDAPError(packet) + packet.Children[1].Children[0].Description = "Result Code (" + LDAPResultCodeMap[err.(*Error).ResultCode] + ")" + packet.Children[1].Children[1].Description = "Matched DN (" + err.(*Error).MatchedDN + ")" packet.Children[1].Children[2].Description = "Error Message" if len(packet.Children[1].Children) > 3 { packet.Children[1].Children[3].Description = "Referral" diff --git a/ldap_test.go b/ldap_test.go index 58f8260e..cf827b0f 100644 --- a/ldap_test.go +++ b/ldap_test.go @@ -271,3 +271,32 @@ func TestCompare(t *testing.T) { fmt.Printf("TestCompare: -> %v\n", sr) } + +func TestMatchDNError(t *testing.T) { + fmt.Printf("TestMatchDNError: starting..\n") + + l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) + if err != nil { + t.Fatal(err.Error()) + } + defer l.Close() + + wrongBase := "ou=roups,dc=umich,dc=edu" + + searchRequest := NewSearchRequest( + wrongBase, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[0], + attributes, + nil) + + _, err = l.Search(searchRequest) + + if err == nil { + t.Errorf("Expected Error, got nil") + return + } + + fmt.Printf("TestMatchDNError: err: %s\n", err.Error()) + +} diff --git a/moddn.go b/moddn.go index 7065b814..803279d2 100644 --- a/moddn.go +++ b/moddn.go @@ -91,9 +91,9 @@ func (l *Conn) ModifyDN(m *ModifyDNRequest) error { } if packet.Children[1].Tag == ApplicationModifyDNResponse { - resultCode, resultDescription := getLDAPResultCode(packet) - if resultCode != 0 { - return NewError(resultCode, errors.New(resultDescription)) + err := GetLDAPError(packet) + if err != nil { + return err } } else { log.Printf("Unexpected Response: %d", packet.Children[1].Tag) diff --git a/modify.go b/modify.go index 67f27918..d83e6221 100644 --- a/modify.go +++ b/modify.go @@ -160,9 +160,9 @@ func (l *Conn) Modify(modifyRequest *ModifyRequest) error { } if packet.Children[1].Tag == ApplicationModifyResponse { - resultCode, resultDescription := getLDAPResultCode(packet) - if resultCode != 0 { - return NewError(resultCode, errors.New(resultDescription)) + err := GetLDAPError(packet) + if err != nil { + return err } } else { log.Printf("Unexpected Response: %d", packet.Children[1].Tag) diff --git a/passwdmodify.go b/passwdmodify.go index 8443babc..06bc21db 100644 --- a/passwdmodify.go +++ b/passwdmodify.go @@ -126,16 +126,16 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa } if packet.Children[1].Tag == ApplicationExtendedResponse { - resultCode, resultDescription := getLDAPResultCode(packet) - if resultCode != 0 { - if resultCode == LDAPResultReferral { + err := GetLDAPError(packet) + if err != nil { + if IsErrorWithCode(err, LDAPResultReferral) { for _, child := range packet.Children[1].Children { if child.Tag == 3 { result.Referral = child.Children[0].Value.(string) } } } - return result, NewError(resultCode, errors.New(resultDescription)) + return result, err } } else { return nil, NewError(ErrorUnexpectedResponse, fmt.Errorf("unexpected Response: %d", packet.Children[1].Tag)) diff --git a/search.go b/search.go index 8d3967da..3aa6dac0 100644 --- a/search.go +++ b/search.go @@ -427,9 +427,9 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { } result.Entries = append(result.Entries, entry) case 5: - resultCode, resultDescription := getLDAPResultCode(packet) - if resultCode != 0 { - return result, NewError(resultCode, errors.New(resultDescription)) + err := GetLDAPError(packet) + if err != nil { + return nil, err } if len(packet.Children) == 3 { for _, child := range packet.Children[2].Children {