diff --git a/bind.go b/bind.go index 7a9ece4e..1ad53823 100644 --- a/bind.go +++ b/bind.go @@ -2,6 +2,7 @@ package ldap import ( "errors" + "fmt" "gopkg.in/asn1-ber.v1" ) @@ -90,7 +91,11 @@ func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResu if len(packet.Children) == 3 { for _, child := range packet.Children[2].Children { - result.Controls = append(result.Controls, DecodeControl(child)) + decodedChild, err := DecodeControl(child) + if err != nil { + return nil, fmt.Errorf("failed to decode child control: %s", err) + } + result.Controls = append(result.Controls, decodedChild) } } diff --git a/compare.go b/compare.go index 82dca33c..c76fdb9e 100644 --- a/compare.go +++ b/compare.go @@ -77,5 +77,5 @@ func (l *Conn) Compare(dn, attribute, value string) (bool, error) { return false, NewError(resultCode, errors.New(resultDescription)) } } - return false, fmt.Errorf("Unexpected Response: %d", packet.Children[1].Tag) + return false, fmt.Errorf("unexpected Response: %d", packet.Children[1].Tag) } diff --git a/control.go b/control.go index b7f181f2..bc946fae 100644 --- a/control.go +++ b/control.go @@ -249,7 +249,7 @@ func FindControl(controls []Control, controlType string) Control { } // DecodeControl returns a control read from the given packet, or nil if no recognized control can be made -func DecodeControl(packet *ber.Packet) Control { +func DecodeControl(packet *ber.Packet) (Control, error) { var ( ControlType = "" Criticality = false @@ -259,7 +259,7 @@ func DecodeControl(packet *ber.Packet) Control { switch len(packet.Children) { case 0: // at least one child is required for control type - return nil + return nil, fmt.Errorf("at least one child is required for control type") case 1: // just type, no criticality or value @@ -292,17 +292,20 @@ func DecodeControl(packet *ber.Packet) Control { default: // more than 3 children is invalid - return nil + return nil, fmt.Errorf("more than 3 children is invalid for controls") } switch ControlType { case ControlTypeManageDsaIT: - return NewControlManageDsaIT(Criticality) + return NewControlManageDsaIT(Criticality), nil case ControlTypePaging: value.Description += " (Paging)" c := new(ControlPaging) if value.Value != nil { - valueChildren := ber.DecodePacket(value.Data.Bytes()) + valueChildren, err := ber.DecodePacketErr(value.Data.Bytes()) + if err != nil { + return nil, fmt.Errorf("failed to decode data bytes: %s", err) + } value.Data.Truncate(0) value.Value = nil value.AppendChild(valueChildren) @@ -314,12 +317,15 @@ func DecodeControl(packet *ber.Packet) Control { c.PagingSize = uint32(value.Children[0].Value.(int64)) c.Cookie = value.Children[1].Data.Bytes() value.Children[1].Value = c.Cookie - return c + return c, nil case ControlTypeBeheraPasswordPolicy: value.Description += " (Password Policy - Behera)" c := NewControlBeheraPasswordPolicy() if value.Value != nil { - valueChildren := ber.DecodePacket(value.Data.Bytes()) + valueChildren, err := ber.DecodePacketErr(value.Data.Bytes()) + if err != nil { + return nil, fmt.Errorf("failed to decode data bytes: %s", err) + } value.Data.Truncate(0) value.Value = nil value.AppendChild(valueChildren) @@ -331,7 +337,10 @@ func DecodeControl(packet *ber.Packet) Control { if child.Tag == 0 { //Warning warningPacket := child.Children[0] - packet := ber.DecodePacket(warningPacket.Data.Bytes()) + packet, err := ber.DecodePacketErr(warningPacket.Data.Bytes()) + if err != nil { + return nil, fmt.Errorf("failed to decode data bytes: %s", err) + } val, ok := packet.Value.(int64) if ok { if warningPacket.Tag == 0 { @@ -346,7 +355,10 @@ func DecodeControl(packet *ber.Packet) Control { } } else if child.Tag == 1 { // Error - packet := ber.DecodePacket(child.Data.Bytes()) + packet, err := ber.DecodePacketErr(child.Data.Bytes()) + if err != nil { + return nil, fmt.Errorf("failed to decode data bytes: %s", err) + } val, ok := packet.Value.(int8) if !ok { // what to do? @@ -357,22 +369,22 @@ func DecodeControl(packet *ber.Packet) Control { c.ErrorString = BeheraPasswordPolicyErrorMap[c.Error] } } - return c + return c, nil case ControlTypeVChuPasswordMustChange: c := &ControlVChuPasswordMustChange{MustChange: true} - return c + return c, nil case ControlTypeVChuPasswordWarning: c := &ControlVChuPasswordWarning{Expire: -1} expireStr := ber.DecodeString(value.Data.Bytes()) expire, err := strconv.ParseInt(expireStr, 10, 64) if err != nil { - return nil + return nil, fmt.Errorf("failed to parse value as int: %s", err) } c.Expire = expire value.Value = c.Expire - return c + return c, nil default: c := new(ControlString) c.ControlType = ControlType @@ -380,7 +392,7 @@ func DecodeControl(packet *ber.Packet) Control { if value != nil { c.ControlValue = value.Value.(string) } - return c + return c, nil } } diff --git a/control_test.go b/control_test.go index 11527463..4c7637d1 100644 --- a/control_test.go +++ b/control_test.go @@ -39,7 +39,10 @@ func runControlTest(t *testing.T, originalControl Control) { encodedBytes := encodedPacket.Bytes() // Decode directly from the encoded packet (ensures Value is correct) - fromPacket := DecodeControl(encodedPacket) + fromPacket, err := DecodeControl(encodedPacket) + if err != nil { + t.Errorf("%sdecoding encoded bytes control failed: %s", header, err) + } if !bytes.Equal(encodedBytes, fromPacket.Encode().Bytes()) { t.Errorf("%sround-trip from encoded packet failed", header) } @@ -48,7 +51,14 @@ func runControlTest(t *testing.T, originalControl Control) { } // Decode from the wire bytes (ensures ber-encoding is correct) - fromBytes := DecodeControl(ber.DecodePacket(encodedBytes)) + pkt, err := ber.DecodePacketErr(encodedBytes) + if err != nil { + t.Errorf("%sdecoding encoded bytes failed: %s", header, err) + } + fromBytes, err := DecodeControl(pkt) + if err != nil { + t.Errorf("%sdecoding control failed: %s", header, err) + } if !bytes.Equal(encodedBytes, fromBytes.Encode().Bytes()) { t.Errorf("%sround-trip from encoded bytes failed", header) } diff --git a/dn.go b/dn.go index 1ee9a1b9..e10a21b7 100644 --- a/dn.go +++ b/dn.go @@ -100,15 +100,15 @@ func ParseDN(str string) (*DN, error) { } // Not a special character, assume hex encoded octet if len(str) == i+1 { - return nil, errors.New("Got corrupted escaped character") + return nil, errors.New("got corrupted escaped character") } dst := []byte{0} n, err := enchex.Decode([]byte(dst), []byte(str[i:i+2])) if err != nil { - return nil, fmt.Errorf("Failed to decode escaped character: %s", err) + return nil, fmt.Errorf("failed to decode escaped character: %s", err) } else if n != 1 { - return nil, fmt.Errorf("Expected 1 byte when un-escaping, got %d", n) + return nil, fmt.Errorf("expected 1 byte when un-escaping, got %d", n) } buffer.WriteByte(dst[0]) i++ @@ -131,9 +131,12 @@ func ParseDN(str string) (*DN, error) { } rawBER, err := enchex.DecodeString(data) if err != nil { - return nil, fmt.Errorf("Failed to decode BER encoding: %s", err) + return nil, fmt.Errorf("failed to decode BER encoding: %s", err) + } + packet, err := ber.DecodePacketErr(rawBER) + if err != nil { + return nil, fmt.Errorf("failed to decode BER packet: %s", err) } - packet := ber.DecodePacket(rawBER) buffer.WriteString(packet.Data.String()) i += len(data) - 1 } diff --git a/dn_test.go b/dn_test.go index af5fc146..7e631691 100644 --- a/dn_test.go +++ b/dn_test.go @@ -76,10 +76,10 @@ func TestSuccessfulDNParsing(t *testing.T) { func TestErrorDNParsing(t *testing.T) { testcases := map[string]string{ "*": "DN ended with incomplete type, value pair", - "cn=Jim\\0Test": "Failed to decode escaped character: encoding/hex: invalid byte: U+0054 'T'", - "cn=Jim\\0": "Got corrupted escaped character", + "cn=Jim\\0Test": "failed to decode escaped character: encoding/hex: invalid byte: U+0054 'T'", + "cn=Jim\\0": "got corrupted escaped character", "DC=example,=net": "DN ended with incomplete type, value pair", - "1=#0402486": "Failed to decode BER encoding: encoding/hex: odd length hex string", + "1=#0402486": "failed to decode BER encoding: encoding/hex: odd length hex string", "test,DC=example,DC=com": "incomplete type, value pair", "=test,DC=example,DC=com": "incomplete type, value pair", } diff --git a/ldap.go b/ldap.go index fe774b54..079f9fe0 100644 --- a/ldap.go +++ b/ldap.go @@ -2,6 +2,7 @@ package ldap import ( "errors" + "fmt" "io/ioutil" "os" @@ -97,13 +98,13 @@ func addLDAPDescriptions(packet *ber.Packet) (err error) { switch application { case ApplicationBindRequest: - addRequestDescriptions(packet) + err = addRequestDescriptions(packet) case ApplicationBindResponse: - addDefaultLDAPResponseDescriptions(packet) + err = addDefaultLDAPResponseDescriptions(packet) case ApplicationUnbindRequest: - addRequestDescriptions(packet) + err = addRequestDescriptions(packet) case ApplicationSearchRequest: - addRequestDescriptions(packet) + err = addRequestDescriptions(packet) case ApplicationSearchResultEntry: packet.Children[1].Children[0].Description = "Object Name" packet.Children[1].Children[1].Description = "Attributes" @@ -116,37 +117,37 @@ func addLDAPDescriptions(packet *ber.Packet) (err error) { } } if len(packet.Children) == 3 { - addControlDescriptions(packet.Children[2]) + err = addControlDescriptions(packet.Children[2]) } case ApplicationSearchResultDone: - addDefaultLDAPResponseDescriptions(packet) + err = addDefaultLDAPResponseDescriptions(packet) case ApplicationModifyRequest: - addRequestDescriptions(packet) + err = addRequestDescriptions(packet) case ApplicationModifyResponse: case ApplicationAddRequest: - addRequestDescriptions(packet) + err = addRequestDescriptions(packet) case ApplicationAddResponse: case ApplicationDelRequest: - addRequestDescriptions(packet) + err = addRequestDescriptions(packet) case ApplicationDelResponse: case ApplicationModifyDNRequest: - addRequestDescriptions(packet) + err = addRequestDescriptions(packet) case ApplicationModifyDNResponse: case ApplicationCompareRequest: - addRequestDescriptions(packet) + err = addRequestDescriptions(packet) case ApplicationCompareResponse: case ApplicationAbandonRequest: - addRequestDescriptions(packet) + err = addRequestDescriptions(packet) case ApplicationSearchResultReference: case ApplicationExtendedRequest: - addRequestDescriptions(packet) + err = addRequestDescriptions(packet) case ApplicationExtendedResponse: } - return nil + return err } -func addControlDescriptions(packet *ber.Packet) { +func addControlDescriptions(packet *ber.Packet) error { packet.Description = "Controls" for _, child := range packet.Children { var value *ber.Packet @@ -155,7 +156,7 @@ func addControlDescriptions(packet *ber.Packet) { switch len(child.Children) { case 0: // at least one child is required for control type - continue + return fmt.Errorf("at least one child is required for control type") case 1: // just type, no criticality or value @@ -184,8 +185,9 @@ func addControlDescriptions(packet *ber.Packet) { default: // more than 3 children is invalid - continue + return fmt.Errorf("more than 3 children for control packet found") } + if value == nil { continue } @@ -193,7 +195,10 @@ func addControlDescriptions(packet *ber.Packet) { case ControlTypePaging: value.Description += " (Paging)" if value.Value != nil { - valueChildren := ber.DecodePacket(value.Data.Bytes()) + valueChildren, err := ber.DecodePacketErr(value.Data.Bytes()) + if err != nil { + return fmt.Errorf("failed to decode data bytes: %s", err) + } value.Data.Truncate(0) value.Value = nil valueChildren.Children[1].Value = valueChildren.Children[1].Data.Bytes() @@ -206,7 +211,10 @@ func addControlDescriptions(packet *ber.Packet) { case ControlTypeBeheraPasswordPolicy: value.Description += " (Password Policy - Behera Draft)" if value.Value != nil { - valueChildren := ber.DecodePacket(value.Data.Bytes()) + valueChildren, err := ber.DecodePacketErr(value.Data.Bytes()) + if err != nil { + return fmt.Errorf("failed to decode data bytes: %s", err) + } value.Data.Truncate(0) value.Value = nil value.AppendChild(valueChildren) @@ -216,7 +224,10 @@ func addControlDescriptions(packet *ber.Packet) { if child.Tag == 0 { //Warning warningPacket := child.Children[0] - packet := ber.DecodePacket(warningPacket.Data.Bytes()) + packet, err := ber.DecodePacketErr(warningPacket.Data.Bytes()) + if err != nil { + return fmt.Errorf("failed to decode data bytes: %s", err) + } val, ok := packet.Value.(int64) if ok { if warningPacket.Tag == 0 { @@ -231,7 +242,10 @@ func addControlDescriptions(packet *ber.Packet) { } } else if child.Tag == 1 { // Error - packet := ber.DecodePacket(child.Data.Bytes()) + packet, err := ber.DecodePacketErr(child.Data.Bytes()) + if err != nil { + return fmt.Errorf("failed to decode data bytes: %s", err) + } val, ok := packet.Value.(int8) if !ok { val = -1 @@ -242,18 +256,20 @@ func addControlDescriptions(packet *ber.Packet) { } } } + return nil } -func addRequestDescriptions(packet *ber.Packet) { +func addRequestDescriptions(packet *ber.Packet) error { packet.Description = "LDAP Request" packet.Children[0].Description = "Message ID" packet.Children[1].Description = ApplicationMap[uint8(packet.Children[1].Tag)] if len(packet.Children) == 3 { - addControlDescriptions(packet.Children[2]) + return addControlDescriptions(packet.Children[2]) } + return nil } -func addDefaultLDAPResponseDescriptions(packet *ber.Packet) { +func addDefaultLDAPResponseDescriptions(packet *ber.Packet) error { resultCode, _ := getLDAPResultCode(packet) packet.Children[1].Children[0].Description = "Result Code (" + LDAPResultCodeMap[resultCode] + ")" packet.Children[1].Children[1].Description = "Matched DN" @@ -262,8 +278,9 @@ func addDefaultLDAPResponseDescriptions(packet *ber.Packet) { packet.Children[1].Children[3].Description = "Referral" } if len(packet.Children) == 3 { - addControlDescriptions(packet.Children[2]) + return addControlDescriptions(packet.Children[2]) } + return nil } // DebugBinaryFile reads and prints packets from the given filename @@ -273,8 +290,13 @@ func DebugBinaryFile(fileName string) error { return NewError(ErrorDebugging, err) } ber.PrintBytes(os.Stdout, file, "") - packet := ber.DecodePacket(file) - addLDAPDescriptions(packet) + packet, err := ber.DecodePacketErr(file) + if err != nil { + return fmt.Errorf("failed to decode packet: %s", err) + } + if err := addLDAPDescriptions(packet); err != nil { + return err + } ber.PrintPacket(packet) return nil diff --git a/passwdmodify.go b/passwdmodify.go index 7d8246fd..20ae6ba3 100644 --- a/passwdmodify.go +++ b/passwdmodify.go @@ -129,7 +129,7 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa return nil, NewError(resultCode, errors.New(resultDescription)) } } else { - return nil, NewError(ErrorUnexpectedResponse, fmt.Errorf("Unexpected Response: %d", packet.Children[1].Tag)) + return nil, NewError(ErrorUnexpectedResponse, fmt.Errorf("unexpected Response: %d", packet.Children[1].Tag)) } extendedResponse := packet.Children[1] diff --git a/search.go b/search.go index d1f0386f..8d3967da 100644 --- a/search.go +++ b/search.go @@ -309,10 +309,10 @@ func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) } else { castControl, ok := control.(*ControlPaging) if !ok { - return nil, fmt.Errorf("Expected paging control to be of type *ControlPaging, got %v", control) + return nil, fmt.Errorf("expected paging control to be of type *ControlPaging, got %v", control) } if castControl.PagingSize != pagingSize { - return nil, fmt.Errorf("Paging size given in search request (%d) conflicts with size given in search call (%d)", castControl.PagingSize, pagingSize) + return nil, fmt.Errorf("paging size given in search request (%d) conflicts with size given in search call (%d)", castControl.PagingSize, pagingSize) } pagingControl = castControl } @@ -433,7 +433,11 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { } if len(packet.Children) == 3 { for _, child := range packet.Children[2].Children { - result.Controls = append(result.Controls, DecodeControl(child)) + decodedChild, err := DecodeControl(child) + if err != nil { + return nil, fmt.Errorf("failed to decode child control: %s", err) + } + result.Controls = append(result.Controls, decodedChild) } } foundSearchResultDone = true