diff --git a/client.go b/client.go index b438d254..5799f39b 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "crypto/tls" "time" ) @@ -32,6 +33,7 @@ type Client interface { PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error) Search(*SearchRequest) (*SearchResult, error) + SearchAsync(ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) DirSync(searchRequest *SearchRequest, flags, maxAttrCount int64, cookie []byte) (*SearchResult, error) } diff --git a/conn.go b/conn.go index d39213b4..474d494b 100644 --- a/conn.go +++ b/conn.go @@ -327,6 +327,8 @@ func (l *Conn) nextMessageID() int64 { // GetLastError returns the last recorded error from goroutines like processMessages and reader. // Only the last recorded error will be returned. func (l *Conn) GetLastError() error { + l.messageMutex.Lock() + defer l.messageMutex.Unlock() return l.err } diff --git a/examples_test.go b/examples_test.go index d788e4f5..61f16197 100644 --- a/examples_test.go +++ b/examples_test.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -50,6 +51,35 @@ func ExampleConn_Search() { } } +// This example demonstrates how to search asynchronously +func ExampleConn_SearchAsync() { + l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389)) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + "dc=example,dc=com", // The base dn to search + ScopeWholeSubtree, NeverDerefAliases, 0, 0, false, + "(&(objectClass=organizationalPerson))", // The filter to apply + []string{"dn", "cn"}, // A list attributes to retrieve + nil, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + r := l.SearchAsync(ctx, searchRequest, 64) + for r.Next() { + entry := r.Entry() + fmt.Printf("%s has DN %s\n", entry.GetAttributeValue("cn"), entry.DN) + } + if err := r.Err(); err != nil { + log.Fatal(err) + } +} + // This example demonstrates how to start a TLS connection func ExampleConn_StartTLS() { l, err := DialURL("ldap://ldap.example.com:389") diff --git a/ldap_test.go b/ldap_test.go index 61417fd5..5b96e039 100644 --- a/ldap_test.go +++ b/ldap_test.go @@ -1,7 +1,9 @@ package ldap import ( + "context" "crypto/tls" + "log" "testing" ber "github.com/go-asn1-ber/asn1-ber" @@ -344,3 +346,67 @@ func TestEscapeDN(t *testing.T) { }) } } + +func TestSearchAsync(t *testing.T) { + l, err := DialURL(ldapServer) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + baseDN, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[2], + attributes, + nil) + + srs := make([]*Entry, 0) + ctx := context.Background() + r := l.SearchAsync(ctx, searchRequest, 64) + for r.Next() { + srs = append(srs, r.Entry()) + } + if err := r.Err(); err != nil { + log.Fatal(err) + } + + t.Logf("TestSearcAsync: %s -> num of entries = %d", searchRequest.Filter, len(srs)) +} + +func TestSearchAsyncAndCancel(t *testing.T) { + l, err := DialURL(ldapServer) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + baseDN, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[2], + attributes, + nil) + + cancelNum := 10 + srs := make([]*Entry, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + r := l.SearchAsync(ctx, searchRequest, 0) + for r.Next() { + srs = append(srs, r.Entry()) + if len(srs) == cancelNum { + cancel() + } + } + if err := r.Err(); err != nil { + log.Fatal(err) + } + + if len(srs) > cancelNum+3 { + // the cancellation process is asynchronous, + // so it might get some entries after calling cancel() + t.Errorf("Got entries %d, expected < %d", len(srs), cancelNum+3) + } + t.Logf("TestSearchAsyncAndCancel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) +} diff --git a/response.go b/response.go new file mode 100644 index 00000000..81d97d9b --- /dev/null +++ b/response.go @@ -0,0 +1,182 @@ +package ldap + +import ( + "context" + "errors" + "fmt" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +// Response defines an interface to get data from an LDAP server +type Response interface { + Entry() *Entry + Referral() string + Controls() []Control + Err() error + Next() bool +} + +type searchResponse struct { + conn *Conn + ch chan *SearchSingleResult + + entry *Entry + referral string + controls []Control + err error +} + +// Entry returns an entry from the given search request +func (r *searchResponse) Entry() *Entry { + return r.entry +} + +// Referral returns a referral from the given search request +func (r *searchResponse) Referral() string { + return r.referral +} + +// Controls returns controls from the given search request +func (r *searchResponse) Controls() []Control { + return r.controls +} + +// Err returns an error when the given search request was failed +func (r *searchResponse) Err() error { + return r.err +} + +// Next returns whether next data exist or not +func (r *searchResponse) Next() bool { + res, ok := <-r.ch + if !ok { + return false + } + if res == nil { + return false + } + r.err = res.Error + if r.err != nil { + return false + } + r.err = r.conn.GetLastError() + if r.err != nil { + return false + } + r.entry = res.Entry + r.referral = res.Referral + r.controls = res.Controls + return true +} + +func (r *searchResponse) start(ctx context.Context, searchRequest *SearchRequest) { + go func() { + defer func() { + close(r.ch) + if err := recover(); err != nil { + r.conn.err = fmt.Errorf("ldap: recovered panic in searchResponse: %v", err) + } + }() + + if r.conn.IsClosing() { + return + } + + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, r.conn.nextMessageID(), "MessageID")) + // encode search request + err := searchRequest.appendTo(packet) + if err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + r.conn.Debug.PrintPacket(packet) + + msgCtx, err := r.conn.sendMessage(packet) + if err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + defer r.conn.finishMessage(msgCtx) + + foundSearchSingleResultDone := false + for !foundSearchSingleResultDone { + select { + case <-ctx.Done(): + r.conn.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error()) + return + default: + r.conn.Debug.Printf("%d: waiting for response", msgCtx.id) + packetResponse, ok := <-msgCtx.responses + if !ok { + err := NewError(ErrorNetwork, errors.New("ldap: response channel closed")) + r.ch <- &SearchSingleResult{Error: err} + return + } + packet, err = packetResponse.ReadPacket() + r.conn.Debug.Printf("%d: got response %p", msgCtx.id, packet) + if err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + + if r.conn.Debug { + if err := addLDAPDescriptions(packet); err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + ber.PrintPacket(packet) + } + + switch packet.Children[1].Tag { + case ApplicationSearchResultEntry: + r.ch <- &SearchSingleResult{ + Entry: &Entry{ + DN: packet.Children[1].Children[0].Value.(string), + Attributes: unpackAttributes(packet.Children[1].Children[1].Children), + }, + } + + case ApplicationSearchResultDone: + if err := GetLDAPError(packet); err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + if len(packet.Children) == 3 { + result := &SearchSingleResult{} + for _, child := range packet.Children[2].Children { + decodedChild, err := DecodeControl(child) + if err != nil { + werr := fmt.Errorf("failed to decode child control: %w", err) + r.ch <- &SearchSingleResult{Error: werr} + return + } + result.Controls = append(result.Controls, decodedChild) + } + r.ch <- result + } + foundSearchSingleResultDone = true + + case ApplicationSearchResultReference: + ref := packet.Children[1].Children[0].Value.(string) + r.ch <- &SearchSingleResult{Referral: ref} + } + } + } + r.conn.Debug.Printf("%d: returning", msgCtx.id) + }() +} + +func newSearchResponse(conn *Conn, bufferSize int) *searchResponse { + var ch chan *SearchSingleResult + if bufferSize > 0 { + ch = make(chan *SearchSingleResult, bufferSize) + } else { + ch = make(chan *SearchSingleResult) + } + return &searchResponse{ + conn: conn, + ch: ch, + } +} diff --git a/search.go b/search.go index ef3119b9..3d8d9e70 100644 --- a/search.go +++ b/search.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "errors" "fmt" "reflect" @@ -375,6 +376,28 @@ func (s *SearchResult) appendTo(r *SearchResult) { r.Controls = append(r.Controls, s.Controls...) } +// SearchSingleResult holds the server's single response to a search request +type SearchSingleResult struct { + // Entry is the returned entry + Entry *Entry + // Referral is the returned referral + Referral string + // Controls are the returned controls + Controls []Control + // Error is set when the search request was failed + Error error +} + +// Print outputs a human-readable description +func (s *SearchSingleResult) Print() { + s.Entry.Print() +} + +// PrettyPrint outputs a human-readable description with indenting +func (s *SearchSingleResult) PrettyPrint(indent int) { + s.Entry.PrettyPrint(indent) +} + // SearchRequest represents a search request to send to the server type SearchRequest struct { BaseDN string @@ -559,6 +582,17 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { } } +// SearchAsync performs a search request and returns all search results asynchronously. +// This means you get all results until an error happens (or the search successfully finished), +// e.g. for size / time limited requests all are recieved until the limit is reached. +// To stop the search, call cancel function returned context. +func (l *Conn) SearchAsync( + ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response { + r := newSearchResponse(l, bufferSize) + r.start(ctx, searchRequest) + return r +} + // unpackAttributes will extract all given LDAP attributes and it's values // from the ber.Packet func unpackAttributes(children []*ber.Packet) []*EntryAttribute { diff --git a/v3/client.go b/v3/client.go index b438d254..5799f39b 100644 --- a/v3/client.go +++ b/v3/client.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "crypto/tls" "time" ) @@ -32,6 +33,7 @@ type Client interface { PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error) Search(*SearchRequest) (*SearchResult, error) + SearchAsync(ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) DirSync(searchRequest *SearchRequest, flags, maxAttrCount int64, cookie []byte) (*SearchResult, error) } diff --git a/v3/conn.go b/v3/conn.go index 3ed80883..a42a9697 100644 --- a/v3/conn.go +++ b/v3/conn.go @@ -327,6 +327,8 @@ func (l *Conn) nextMessageID() int64 { // GetLastError returns the last recorded error from goroutines like processMessages and reader. // // Only the last recorded error will be returned. func (l *Conn) GetLastError() error { + l.messageMutex.Lock() + defer l.messageMutex.Unlock() return l.err } diff --git a/v3/examples_test.go b/v3/examples_test.go index d788e4f5..61f16197 100644 --- a/v3/examples_test.go +++ b/v3/examples_test.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -50,6 +51,35 @@ func ExampleConn_Search() { } } +// This example demonstrates how to search asynchronously +func ExampleConn_SearchAsync() { + l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389)) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + "dc=example,dc=com", // The base dn to search + ScopeWholeSubtree, NeverDerefAliases, 0, 0, false, + "(&(objectClass=organizationalPerson))", // The filter to apply + []string{"dn", "cn"}, // A list attributes to retrieve + nil, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + r := l.SearchAsync(ctx, searchRequest, 64) + for r.Next() { + entry := r.Entry() + fmt.Printf("%s has DN %s\n", entry.GetAttributeValue("cn"), entry.DN) + } + if err := r.Err(); err != nil { + log.Fatal(err) + } +} + // This example demonstrates how to start a TLS connection func ExampleConn_StartTLS() { l, err := DialURL("ldap://ldap.example.com:389") diff --git a/v3/ldap_test.go b/v3/ldap_test.go index 61417fd5..5b96e039 100644 --- a/v3/ldap_test.go +++ b/v3/ldap_test.go @@ -1,7 +1,9 @@ package ldap import ( + "context" "crypto/tls" + "log" "testing" ber "github.com/go-asn1-ber/asn1-ber" @@ -344,3 +346,67 @@ func TestEscapeDN(t *testing.T) { }) } } + +func TestSearchAsync(t *testing.T) { + l, err := DialURL(ldapServer) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + baseDN, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[2], + attributes, + nil) + + srs := make([]*Entry, 0) + ctx := context.Background() + r := l.SearchAsync(ctx, searchRequest, 64) + for r.Next() { + srs = append(srs, r.Entry()) + } + if err := r.Err(); err != nil { + log.Fatal(err) + } + + t.Logf("TestSearcAsync: %s -> num of entries = %d", searchRequest.Filter, len(srs)) +} + +func TestSearchAsyncAndCancel(t *testing.T) { + l, err := DialURL(ldapServer) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + baseDN, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[2], + attributes, + nil) + + cancelNum := 10 + srs := make([]*Entry, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + r := l.SearchAsync(ctx, searchRequest, 0) + for r.Next() { + srs = append(srs, r.Entry()) + if len(srs) == cancelNum { + cancel() + } + } + if err := r.Err(); err != nil { + log.Fatal(err) + } + + if len(srs) > cancelNum+3 { + // the cancellation process is asynchronous, + // so it might get some entries after calling cancel() + t.Errorf("Got entries %d, expected < %d", len(srs), cancelNum+3) + } + t.Logf("TestSearchAsyncAndCancel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) +} diff --git a/v3/response.go b/v3/response.go new file mode 100644 index 00000000..81d97d9b --- /dev/null +++ b/v3/response.go @@ -0,0 +1,182 @@ +package ldap + +import ( + "context" + "errors" + "fmt" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +// Response defines an interface to get data from an LDAP server +type Response interface { + Entry() *Entry + Referral() string + Controls() []Control + Err() error + Next() bool +} + +type searchResponse struct { + conn *Conn + ch chan *SearchSingleResult + + entry *Entry + referral string + controls []Control + err error +} + +// Entry returns an entry from the given search request +func (r *searchResponse) Entry() *Entry { + return r.entry +} + +// Referral returns a referral from the given search request +func (r *searchResponse) Referral() string { + return r.referral +} + +// Controls returns controls from the given search request +func (r *searchResponse) Controls() []Control { + return r.controls +} + +// Err returns an error when the given search request was failed +func (r *searchResponse) Err() error { + return r.err +} + +// Next returns whether next data exist or not +func (r *searchResponse) Next() bool { + res, ok := <-r.ch + if !ok { + return false + } + if res == nil { + return false + } + r.err = res.Error + if r.err != nil { + return false + } + r.err = r.conn.GetLastError() + if r.err != nil { + return false + } + r.entry = res.Entry + r.referral = res.Referral + r.controls = res.Controls + return true +} + +func (r *searchResponse) start(ctx context.Context, searchRequest *SearchRequest) { + go func() { + defer func() { + close(r.ch) + if err := recover(); err != nil { + r.conn.err = fmt.Errorf("ldap: recovered panic in searchResponse: %v", err) + } + }() + + if r.conn.IsClosing() { + return + } + + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, r.conn.nextMessageID(), "MessageID")) + // encode search request + err := searchRequest.appendTo(packet) + if err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + r.conn.Debug.PrintPacket(packet) + + msgCtx, err := r.conn.sendMessage(packet) + if err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + defer r.conn.finishMessage(msgCtx) + + foundSearchSingleResultDone := false + for !foundSearchSingleResultDone { + select { + case <-ctx.Done(): + r.conn.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error()) + return + default: + r.conn.Debug.Printf("%d: waiting for response", msgCtx.id) + packetResponse, ok := <-msgCtx.responses + if !ok { + err := NewError(ErrorNetwork, errors.New("ldap: response channel closed")) + r.ch <- &SearchSingleResult{Error: err} + return + } + packet, err = packetResponse.ReadPacket() + r.conn.Debug.Printf("%d: got response %p", msgCtx.id, packet) + if err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + + if r.conn.Debug { + if err := addLDAPDescriptions(packet); err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + ber.PrintPacket(packet) + } + + switch packet.Children[1].Tag { + case ApplicationSearchResultEntry: + r.ch <- &SearchSingleResult{ + Entry: &Entry{ + DN: packet.Children[1].Children[0].Value.(string), + Attributes: unpackAttributes(packet.Children[1].Children[1].Children), + }, + } + + case ApplicationSearchResultDone: + if err := GetLDAPError(packet); err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + if len(packet.Children) == 3 { + result := &SearchSingleResult{} + for _, child := range packet.Children[2].Children { + decodedChild, err := DecodeControl(child) + if err != nil { + werr := fmt.Errorf("failed to decode child control: %w", err) + r.ch <- &SearchSingleResult{Error: werr} + return + } + result.Controls = append(result.Controls, decodedChild) + } + r.ch <- result + } + foundSearchSingleResultDone = true + + case ApplicationSearchResultReference: + ref := packet.Children[1].Children[0].Value.(string) + r.ch <- &SearchSingleResult{Referral: ref} + } + } + } + r.conn.Debug.Printf("%d: returning", msgCtx.id) + }() +} + +func newSearchResponse(conn *Conn, bufferSize int) *searchResponse { + var ch chan *SearchSingleResult + if bufferSize > 0 { + ch = make(chan *SearchSingleResult, bufferSize) + } else { + ch = make(chan *SearchSingleResult) + } + return &searchResponse{ + conn: conn, + ch: ch, + } +} diff --git a/v3/search.go b/v3/search.go index 9c0ccd07..afac768c 100644 --- a/v3/search.go +++ b/v3/search.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "errors" "fmt" "reflect" @@ -377,6 +378,28 @@ func (s *SearchResult) appendTo(r *SearchResult) { r.Controls = append(r.Controls, s.Controls...) } +// SearchSingleResult holds the server's single entry response to a search request +type SearchSingleResult struct { + // Entry is the returned entry + Entry *Entry + // Referral is the returned referral + Referral string + // Controls are the returned controls + Controls []Control + // Error is set when the search request was failed + Error error +} + +// Print outputs a human-readable description +func (s *SearchSingleResult) Print() { + s.Entry.Print() +} + +// PrettyPrint outputs a human-readable description with indenting +func (s *SearchSingleResult) PrettyPrint(indent int) { + s.Entry.PrettyPrint(indent) +} + // SearchRequest represents a search request to send to the server type SearchRequest struct { BaseDN string @@ -561,6 +584,17 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { } } +// SearchAsync performs a search request and returns all search results asynchronously. +// This means you get all results until an error happens (or the search successfully finished), +// e.g. for size / time limited requests all are recieved until the limit is reached. +// To stop the search, call cancel function returned context. +func (l *Conn) SearchAsync( + ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response { + r := newSearchResponse(l, bufferSize) + r.start(ctx, searchRequest) + return r +} + // unpackAttributes will extract all given LDAP attributes and it's values // from the ber.Packet func unpackAttributes(children []*ber.Packet) []*EntryAttribute {