From 21d1415bcac35ef3e8e35abf1c0a6af46977dc3d Mon Sep 17 00:00:00 2001 From: Christopher Puschmann Date: Sat, 16 Sep 2023 13:51:47 +0200 Subject: [PATCH] chore: Mirror v3 to root directory (#468) --- bind.go | 2 +- conn.go | 5 +- conn_test.go | 26 +++--- control.go | 235 ++++++++++++++++++++++++++++++++++++++++++++---- control_test.go | 53 +++++++++++ search.go | 4 +- v3/conn.go | 2 +- 7 files changed, 292 insertions(+), 35 deletions(-) diff --git a/bind.go b/bind.go index 09914b27..a37f8e2c 100644 --- a/bind.go +++ b/bind.go @@ -614,7 +614,7 @@ func (l *Conn) GSSAPIBind(client GSSAPIClient, servicePrincipal, authzid string) // GSSAPIBindRequest performs the GSSAPI SASL bind using the provided GSSAPI client. func (l *Conn) GSSAPIBindRequest(client GSSAPIClient, req *GSSAPIBindRequest) error { - // nolint:errcheck + //nolint:errcheck defer client.DeleteSecContext() var err error diff --git a/conn.go b/conn.go index db4667ca..6668dcdf 100644 --- a/conn.go +++ b/conn.go @@ -288,10 +288,9 @@ func (l *Conn) Close() (err error) { l.chanMessage <- &messagePacket{Op: MessageQuit} timeoutCtx := context.Background() - requestTimeout := l.getTimeout() - if requestTimeout > 0 { + if l.getTimeout() > 0 { var cancelFunc context.CancelFunc - timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(requestTimeout)) + timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.getTimeout())) defer cancelFunc() } select { diff --git a/conn_test.go b/conn_test.go index 0f6d7181..d0bfc0c5 100644 --- a/conn_test.go +++ b/conn_test.go @@ -58,9 +58,7 @@ func TestUnresponsiveConnection(t *testing.T) { } } -// TestInvalidStateCloseDeadlock tests that we do not enter deadlock when the -// message handler is blocked or inactive. -func TestInvalidStateCloseDeadlock(t *testing.T) { +func TestRequestTimeoutDeadlock(t *testing.T) { // The do-nothing server that accepts requests and does nothing ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { })) @@ -72,14 +70,23 @@ func TestInvalidStateCloseDeadlock(t *testing.T) { // Create an Ldap connection conn := NewConn(c, false) - conn.SetTimeout(time.Millisecond) + conn.Start() + // trigger a race condition on accessing request timeout + n := 3 + for i := 0; i < n; i++ { + go func() { + conn.SetTimeout(time.Millisecond) + }() + } // Attempt to close the connection when the message handler is // blocked or inactive conn.Close() } -func TestRequestTimeoutDeadlock(t *testing.T) { +// TestInvalidStateCloseDeadlock tests that we do not enter deadlock when the +// message handler is blocked or inactive. +func TestInvalidStateCloseDeadlock(t *testing.T) { // The do-nothing server that accepts requests and does nothing ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { })) @@ -91,14 +98,7 @@ func TestRequestTimeoutDeadlock(t *testing.T) { // Create an Ldap connection conn := NewConn(c, false) - conn.Start() - // trigger a race condition on accessing request timeout - n := 3 - for i := 0; i < n; i++ { - go func() { - conn.SetTimeout(time.Millisecond) - }() - } + conn.SetTimeout(time.Millisecond) // Attempt to close the connection when the message handler is // blocked or inactive diff --git a/control.go b/control.go index e3ed683c..60453deb 100644 --- a/control.go +++ b/control.go @@ -24,6 +24,11 @@ const ( // ControlTypeSubtreeDelete - https://datatracker.ietf.org/doc/html/draft-armijo-ldap-treedelete-02 ControlTypeSubtreeDelete = "1.2.840.113556.1.4.805" + // ControlTypeServerSideSorting - https://www.ietf.org/rfc/rfc2891.txt + ControlTypeServerSideSorting = "1.2.840.113556.1.4.473" + // ControlTypeServerSideSorting - https://www.ietf.org/rfc/rfc2891.txt + ControlTypeServerSideSortingResult = "1.2.840.113556.1.4.474" + // ControlTypeMicrosoftNotification - https://msdn.microsoft.com/en-us/library/aa366983(v=vs.85).aspx ControlTypeMicrosoftNotification = "1.2.840.113556.1.4.528" // ControlTypeMicrosoftShowDeleted - https://msdn.microsoft.com/en-us/library/aa366989(v=vs.85).aspx @@ -53,18 +58,20 @@ const ( // ControlTypeMap maps controls to text descriptions var ControlTypeMap = map[string]string{ - ControlTypePaging: "Paging", - ControlTypeBeheraPasswordPolicy: "Password Policy - Behera Draft", - ControlTypeManageDsaIT: "Manage DSA IT", - ControlTypeSubtreeDelete: "Subtree Delete Control", - ControlTypeMicrosoftNotification: "Change Notification - Microsoft", - ControlTypeMicrosoftShowDeleted: "Show Deleted Objects - Microsoft", - ControlTypeMicrosoftServerLinkTTL: "Return TTL-DNs for link values with associated expiry times - Microsoft", - ControlTypeDirSync: "DirSync", - ControlTypeSyncRequest: "Sync Request", - ControlTypeSyncState: "Sync State", - ControlTypeSyncDone: "Sync Done", - ControlTypeSyncInfo: "Sync Info", + ControlTypePaging: "Paging", + ControlTypeBeheraPasswordPolicy: "Password Policy - Behera Draft", + ControlTypeManageDsaIT: "Manage DSA IT", + ControlTypeSubtreeDelete: "Subtree Delete Control", + ControlTypeMicrosoftNotification: "Change Notification - Microsoft", + ControlTypeMicrosoftShowDeleted: "Show Deleted Objects - Microsoft", + ControlTypeMicrosoftServerLinkTTL: "Return TTL-DNs for link values with associated expiry times - Microsoft", + ControlTypeServerSideSorting: "Server Side Sorting Request - LDAP Control Extension for Server Side Sorting of Search Results (RFC2891)", + ControlTypeServerSideSortingResult: "Server Side Sorting Results - LDAP Control Extension for Server Side Sorting of Search Results (RFC2891)", + ControlTypeDirSync: "DirSync", + ControlTypeSyncRequest: "Sync Request", + ControlTypeSyncState: "Sync State", + ControlTypeSyncDone: "Sync Done", + ControlTypeSyncInfo: "Sync Info", } // Control defines an interface controls provide to encode and describe themselves @@ -521,6 +528,10 @@ func DecodeControl(packet *ber.Packet) (Control, error) { return NewControlMicrosoftServerLinkTTL(), nil case ControlTypeSubtreeDelete: return NewControlSubtreeDelete(), nil + case ControlTypeServerSideSorting: + return NewControlServerSideSorting(value) + case ControlTypeServerSideSortingResult: + return NewControlServerSideSortingResult(value) case ControlTypeDirSync: value.Description += " (DirSync)" return NewResponseControlDirSync(value) @@ -716,6 +727,193 @@ func (c *ControlDirSync) SetCookie(cookie []byte) { c.Cookie = cookie } +// ControlServerSideSorting + +type SortKey struct { + Reverse bool + AttributeType string + MatchingRule string +} + +type ControlServerSideSorting struct { + SortKeys []*SortKey +} + +func (c *ControlServerSideSorting) GetControlType() string { + return ControlTypeServerSideSorting +} + +func NewControlServerSideSorting(value *ber.Packet) (*ControlServerSideSorting, error) { + sortKeys := []*SortKey{} + + val := value.Children[1].Children + + if len(val) != 1 { + return nil, fmt.Errorf("no sequence value in packet") + } + + sequences := val[0].Children + + for i, sequence := range sequences { + sortKey := &SortKey{} + + if len(sequence.Children) < 2 { + return nil, fmt.Errorf("attributeType or matchingRule is missing from sequence %d", i) + } + + sortKey.AttributeType = sequence.Children[0].Value.(string) + sortKey.MatchingRule = sequence.Children[1].Value.(string) + + if len(sequence.Children) == 3 { + sortKey.Reverse = sequence.Children[2].Value.(bool) + } + + sortKeys = append(sortKeys, sortKey) + } + + return &ControlServerSideSorting{SortKeys: sortKeys}, nil +} + +func NewControlServerSideSortingWithSortKeys(sortKeys []*SortKey) *ControlServerSideSorting { + return &ControlServerSideSorting{SortKeys: sortKeys} +} + +func (c *ControlServerSideSorting) Encode() *ber.Packet { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") + control := ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, c.GetControlType(), "Control Type") + + value := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Control Value") + seqs := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "SortKeyList") + + for _, f := range c.SortKeys { + seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "") + + seq.AppendChild( + ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, f.AttributeType, "attributeType"), + ) + seq.AppendChild( + ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, f.MatchingRule, "orderingRule"), + ) + if f.Reverse { + seq.AppendChild( + ber.NewBoolean(ber.ClassContext, ber.TypePrimitive, 1, f.Reverse, "reverseOrder"), + ) + } + + seqs.AppendChild(seq) + } + + value.AppendChild(seqs) + + packet.AppendChild(control) + packet.AppendChild(value) + + return packet +} + +func (c *ControlServerSideSorting) String() string { + return fmt.Sprintf( + "Control Type: %s (%q) Criticality:%t %+v", + "Server Side Sorting", + c.GetControlType(), + false, + c.SortKeys, + ) +} + +// ControlServerSideSortingResponse + +const ( + ControlServerSideSortingCodeSuccess ControlServerSideSortingCode = 0 + ControlServerSideSortingCodeOperationsError ControlServerSideSortingCode = 1 + ControlServerSideSortingCodeTimeLimitExceeded ControlServerSideSortingCode = 2 + ControlServerSideSortingCodeStrongAuthRequired ControlServerSideSortingCode = 8 + ControlServerSideSortingCodeAdminLimitExceeded ControlServerSideSortingCode = 11 + ControlServerSideSortingCodeNoSuchAttribute ControlServerSideSortingCode = 16 + ControlServerSideSortingCodeInappropriateMatching ControlServerSideSortingCode = 18 + ControlServerSideSortingCodeInsufficientAccessRights ControlServerSideSortingCode = 50 + ControlServerSideSortingCodeBusy ControlServerSideSortingCode = 51 + ControlServerSideSortingCodeUnwillingToPerform ControlServerSideSortingCode = 53 + ControlServerSideSortingCodeOther ControlServerSideSortingCode = 80 +) + +var ControlServerSideSortingCodes = []ControlServerSideSortingCode{ + ControlServerSideSortingCodeSuccess, + ControlServerSideSortingCodeOperationsError, + ControlServerSideSortingCodeTimeLimitExceeded, + ControlServerSideSortingCodeStrongAuthRequired, + ControlServerSideSortingCodeAdminLimitExceeded, + ControlServerSideSortingCodeNoSuchAttribute, + ControlServerSideSortingCodeInappropriateMatching, + ControlServerSideSortingCodeInsufficientAccessRights, + ControlServerSideSortingCodeBusy, + ControlServerSideSortingCodeUnwillingToPerform, + ControlServerSideSortingCodeOther, +} + +type ControlServerSideSortingCode int64 + +// Valid test the code contained in the control against the ControlServerSideSortingCodes slice and return an error if the code is unknown. +func (c ControlServerSideSortingCode) Valid() error { + for _, validRet := range ControlServerSideSortingCodes { + if c == validRet { + return nil + } + } + return fmt.Errorf("unknown return code : %d", c) +} + +func NewControlServerSideSortingResult(pkt *ber.Packet) (*ControlServerSideSortingResult, error) { + control := &ControlServerSideSortingResult{} + + if pkt == nil || len(pkt.Children) == 0 { + return nil, fmt.Errorf("bad packet") + } + + codeInt, err := ber.ParseInt64(pkt.Children[0].Data.Bytes()) + if err != nil { + return nil, err + } + + code := ControlServerSideSortingCode(codeInt) + if err := code.Valid(); err != nil { + return nil, err + } + + return control, nil +} + +type ControlServerSideSortingResult struct { + Criticality bool + + Result ControlServerSideSortingCode + + // Not populated for now. I can't get openldap to send me this value, so I think this is specific to other directory server + // AttributeType string +} + +func (control *ControlServerSideSortingResult) GetControlType() string { + return ControlTypeServerSideSortingResult +} + +func (c *ControlServerSideSortingResult) Encode() *ber.Packet { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "SortResult sequence") + sortResult := ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, int64(c.Result), "SortResult") + packet.AppendChild(sortResult) + + return packet +} + +func (c *ControlServerSideSortingResult) String() string { + return fmt.Sprintf( + "Control Type: %s (%q) Criticality:%t ResultCode:%+v", + "Server Side Sorting Result", + c.GetControlType(), + c.Criticality, + c.Result, + ) +} + // Mode for ControlTypeSyncRequest type ControlSyncRequestMode int64 @@ -752,9 +950,12 @@ func (c *ControlSyncRequest) GetControlType() string { func (c *ControlSyncRequest) Encode() *ber.Packet { _mode := int64(c.Mode) mode := ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, _mode, "Mode") - cookie := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Cookie") - cookie.Value = c.Cookie - cookie.Data.Write(c.Cookie) + var cookie *ber.Packet + if len(c.Cookie) > 0 { + cookie = ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Cookie") + cookie.Value = c.Cookie + cookie.Data.Write(c.Cookie) + } reloadHint := ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, c.ReloadHint, "Reload Hint") packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") @@ -764,7 +965,9 @@ func (c *ControlSyncRequest) Encode() *ber.Packet { val := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Control Value (Sync Request)") seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Sync Request Value") seq.AppendChild(mode) - seq.AppendChild(cookie) + if cookie != nil { + seq.AppendChild(cookie) + } seq.AppendChild(reloadHint) val.AppendChild(seq) diff --git a/control_test.go b/control_test.go index 118f9644..5f43ccb2 100644 --- a/control_test.go +++ b/control_test.go @@ -219,3 +219,56 @@ func TestDecodeControl(t *testing.T) { }) } } + +func TestControlServerSideSortingDecoding(t *testing.T) { + control := NewControlServerSideSortingWithSortKeys([]*SortKey{{ + MatchingRule: "foo", + AttributeType: "foobar", + Reverse: true, + }, { + MatchingRule: "foo", + AttributeType: "foobar", + Reverse: false, + }, { + MatchingRule: "", + AttributeType: "", + Reverse: false, + }, { + MatchingRule: "totoRule", + AttributeType: "", + Reverse: false, + }, { + MatchingRule: "", + AttributeType: "totoType", + Reverse: false, + }}) + + controlDecoded, err := NewControlServerSideSorting(control.Encode()) + if err != nil { + t.Fatal(err) + } + + if control.GetControlType() != controlDecoded.GetControlType() { + t.Fatalf("control type mismatch: control:%s - decoded:%s", control.GetControlType(), controlDecoded.GetControlType()) + } + + if len(control.SortKeys) != len(controlDecoded.SortKeys) { + t.Fatalf("sort keys length mismatch (control: %d - decoded: %d)", len(control.SortKeys), len(controlDecoded.SortKeys)) + } + + for i, sk := range control.SortKeys { + dsk := controlDecoded.SortKeys[i] + + if sk.AttributeType != dsk.AttributeType { + t.Fatalf("attribute type mismatch for sortkey %d", i) + } + + if sk.MatchingRule != dsk.MatchingRule { + t.Fatalf("matching rule mismatch for sortkey %d", i) + } + + if sk.Reverse != dsk.Reverse { + t.Fatalf("reverse mismtach for sortkey %d", i) + } + } +} diff --git a/search.go b/search.go index ca86cc9f..4eb10762 100644 --- a/search.go +++ b/search.go @@ -220,6 +220,7 @@ func readTag(f reflect.StructField) (string, bool) { // // // Time is parsed with the generalizedTime spec into a time.Time // Created time.Time `ldap:"createdTimestamp"` +// // // *DN is parsed with the ParseDN // Owner *ldap.DN `ldap:"owner"` // @@ -232,6 +233,7 @@ func readTag(f reflect.StructField) (string, bool) { // UserAccountControl uint32 `ldap:"userPrincipalName"` // } // user := UserEntry{} +// // if err := result.Unmarshal(&user); err != nil { // // ... // } @@ -376,7 +378,7 @@ func (s *SearchResult) appendTo(r *SearchResult) { r.Controls = append(r.Controls, s.Controls...) } -// SearchSingleResult holds the server's single response to a search request +// SearchSingleResult holds the server's single entry response to a search request type SearchSingleResult struct { // Entry is the returned entry Entry *Entry diff --git a/v3/conn.go b/v3/conn.go index 4de22fda..6d083621 100644 --- a/v3/conn.go +++ b/v3/conn.go @@ -329,7 +329,7 @@ func (l *Conn) nextMessageID() int64 { } // GetLastError returns the last recorded error from goroutines like processMessages and reader. -// // Only the last recorded error will be returned. +// Only the last recorded error will be returned. func (l *Conn) GetLastError() error { l.messageMutex.Lock() defer l.messageMutex.Unlock()