From e5aaa0fa567b123f68c35a0a5e2b00e260fe771d Mon Sep 17 00:00:00 2001 From: Josh Hawn Date: Fri, 17 Jun 2016 13:39:25 -0700 Subject: [PATCH] Use Message Context This patch introduces a new type called `massageContext` which is now the return value of `(*Conn).sendMessage()`. The message context object still contains a channel from which methods like, Add(), Bind(), Search(), etc., will receive response packets. It also has a field which holds the message ID as well as a `done` channel which is used to prevent deadlock in the `processMessages()` goroutine. This is accomplished by also changing the `(*Conn).finishMessage()` method to take a message context and close this `done` channel before sending a `MessageFinish` packet to the `processMessages()` goroutine. The `processMessages()` goroutine now has a `massageContexts` map which replaces the `chanResults` map. Now, rather than sending response packets only on the response channels, the `messageContext` has its own `sendResponse()` method which uses a switch that blocks on sending a response packet *or* waiting for its `done` channel to be closed by `finishMessage()`. Docker-DCO-1.1-Signed-off-by: Josh Hawn (github: jlhawn) --- add.go | 23 ++-- bind.go | 34 ++---- compare.go | 18 ++- conn.go | 114 +++++++++++-------- conn_test.go | 295 +++++++++++++++++++++++++++++++++++++++++++++++- del.go | 20 ++-- modify.go | 20 ++-- passwdmodify.go | 19 ++-- search.go | 20 ++-- 9 files changed, 416 insertions(+), 147 deletions(-) diff --git a/add.go b/add.go index 48db8212..7e00cbcc 100644 --- a/add.go +++ b/add.go @@ -21,8 +21,6 @@ type Attribute struct { Vals []string } - - func (a *Attribute) encode() *ber.Packet { seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute") seq.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, a.Type, "Type")) @@ -39,7 +37,6 @@ type AddRequest struct { Attributes []Attribute } - func (a AddRequest) encode() *ber.Packet { request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request") request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, a.DN, "DN")) @@ -63,29 +60,25 @@ func NewAddRequest(dn string) *AddRequest { } func (l *Conn) Add(addRequest *AddRequest) error { - messageID := l.nextMessageID() packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) packet.AppendChild(addRequest.encode()) l.Debug.PrintPacket(packet) - channel, err := l.sendMessage(packet) + msgCtx, err := l.sendMessage(packet) if err != nil { return err } - if channel == nil { - return NewError(ErrorNetwork, errors.New("ldap: could not send message")) - } - defer l.finishMessage(messageID) + defer l.finishMessage(msgCtx) - l.Debug.Printf("%d: waiting for response", messageID) - packetResponse, ok := <-channel + l.Debug.Printf("%d: waiting for response", msgCtx.id) + packetResponse, ok := <-msgCtx.responses if !ok { - return NewError(ErrorNetwork, errors.New("ldap: channel closed")) + return NewError(ErrorNetwork, errors.New("ldap: response channel closed")) } packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", messageID, packet) + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) if err != nil { return err } @@ -106,6 +99,6 @@ func (l *Conn) Add(addRequest *AddRequest) error { log.Printf("Unexpected Response: %d", packet.Children[1].Tag) } - l.Debug.Printf("%d: returning", messageID) + l.Debug.Printf("%d: returning", msgCtx.id) return nil } diff --git a/bind.go b/bind.go index ae68eb48..e0a7afc5 100644 --- a/bind.go +++ b/bind.go @@ -40,10 +40,8 @@ func (bindRequest *SimpleBindRequest) encode() *ber.Packet { } func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error) { - messageID := l.nextMessageID() - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) encodedBindRequest := simpleBindRequest.encode() packet.AppendChild(encodedBindRequest) @@ -51,21 +49,18 @@ func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResu ber.PrintPacket(packet) } - channel, err := l.sendMessage(packet) + msgCtx, err := l.sendMessage(packet) if err != nil { return nil, err } - if channel == nil { - return nil, NewError(ErrorNetwork, errors.New("ldap: could not send message")) - } - defer l.finishMessage(messageID) + defer l.finishMessage(msgCtx) - packetResponse, ok := <-channel + packetResponse, ok := <-msgCtx.responses if !ok { - return nil, NewError(ErrorNetwork, errors.New("ldap: channel closed")) + return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed")) } packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", messageID, packet) + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) if err != nil { return nil, err } @@ -96,10 +91,8 @@ func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResu } func (l *Conn) Bind(username, password string) error { - messageID := l.nextMessageID() - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) bindRequest.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, username, "User Name")) @@ -110,21 +103,18 @@ func (l *Conn) Bind(username, password string) error { ber.PrintPacket(packet) } - channel, err := l.sendMessage(packet) + msgCtx, err := l.sendMessage(packet) if err != nil { return err } - if channel == nil { - return NewError(ErrorNetwork, errors.New("ldap: could not send message")) - } - defer l.finishMessage(messageID) + defer l.finishMessage(msgCtx) - packetResponse, ok := <-channel + packetResponse, ok := <-msgCtx.responses if !ok { - return NewError(ErrorNetwork, errors.New("ldap: channel closed")) + return NewError(ErrorNetwork, errors.New("ldap: response channel closed")) } packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", messageID, packet) + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) if err != nil { return err } diff --git a/compare.go b/compare.go index dfe728ba..cc6d2af5 100644 --- a/compare.go +++ b/compare.go @@ -33,9 +33,8 @@ import ( // Compare checks to see if the attribute of the dn matches value. Returns true if it does otherwise // false with any error that occurs if any. func (l *Conn) Compare(dn, attribute, value string) (bool, error) { - messageID := l.nextMessageID() packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationCompareRequest, nil, "Compare Request") request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, dn, "DN")) @@ -48,22 +47,19 @@ func (l *Conn) Compare(dn, attribute, value string) (bool, error) { l.Debug.PrintPacket(packet) - channel, err := l.sendMessage(packet) + msgCtx, err := l.sendMessage(packet) if err != nil { return false, err } - if channel == nil { - return false, NewError(ErrorNetwork, errors.New("ldap: could not send message")) - } - defer l.finishMessage(messageID) + defer l.finishMessage(msgCtx) - l.Debug.Printf("%d: waiting for response", messageID) - packetResponse, ok := <-channel + l.Debug.Printf("%d: waiting for response", msgCtx.id) + packetResponse, ok := <-msgCtx.responses if !ok { - return false, NewError(ErrorNetwork, errors.New("ldap: channel closed")) + return false, NewError(ErrorNetwork, errors.New("ldap: response channel closed")) } packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", messageID, packet) + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) if err != nil { return false, err } diff --git a/conn.go b/conn.go index a037aff9..538166e7 100644 --- a/conn.go +++ b/conn.go @@ -36,11 +36,29 @@ func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) { return pr.Packet, pr.Error } +type messageContext struct { + id int64 + done chan struct{} + responses chan *PacketResponse +} + +// sendResponse should only be called within the processMessages() loop which +// is also responsible for closing the responses channel. +func (msgCtx *messageContext) sendResponse(packet *PacketResponse) { + select { + case msgCtx.responses <- packet: + // Successfully sent packet to message handler. + case <-msgCtx.done: + // The request handler is done and will not receive more + // packets. + } +} + type messagePacket struct { Op int MessageID int64 Packet *ber.Packet - Channel chan *PacketResponse + Context *messageContext } type sendMessageFlags uint @@ -58,7 +76,7 @@ type Conn struct { isStartingTLS bool Debug debugging chanConfirm chan bool - chanResults map[int64]chan *PacketResponse + messageContexts map[int64]*messageContext chanMessage chan *messagePacket chanMessageID chan int64 wgSender sync.WaitGroup @@ -112,13 +130,13 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) { // NewConn returns a new Conn using conn for network I/O. func NewConn(conn net.Conn, isTLS bool) *Conn { return &Conn{ - conn: conn, - chanConfirm: make(chan bool), - chanMessageID: make(chan int64), - chanMessage: make(chan *messagePacket, 10), - chanResults: map[int64]chan *PacketResponse{}, - requestTimeout: 0, - isTLS: isTLS, + conn: conn, + chanConfirm: make(chan bool), + chanMessageID: make(chan int64), + chanMessage: make(chan *messagePacket, 10), + messageContexts: map[int64]*messageContext{}, + requestTimeout: 0, + isTLS: isTLS, } } @@ -168,35 +186,31 @@ func (l *Conn) nextMessageID() int64 { // StartTLS sends the command to start a TLS session and then creates a new TLS Client func (l *Conn) StartTLS(config *tls.Config) error { - messageID := l.nextMessageID() - if l.isTLS { return NewError(ErrorNetwork, errors.New("ldap: already encrypted")) } packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS") request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command")) packet.AppendChild(request) l.Debug.PrintPacket(packet) - channel, err := l.sendMessageWithFlags(packet, startTLS) + msgCtx, err := l.sendMessageWithFlags(packet, startTLS) if err != nil { return err } - if channel == nil { - return NewError(ErrorNetwork, errors.New("ldap: could not send message")) - } + defer l.finishMessage(msgCtx) + + l.Debug.Printf("%d: waiting for response", msgCtx.id) - l.Debug.Printf("%d: waiting for response", messageID) - defer l.finishMessage(messageID) - packetResponse, ok := <-channel + packetResponse, ok := <-msgCtx.responses if !ok { - return NewError(ErrorNetwork, errors.New("ldap: channel closed")) + return NewError(ErrorNetwork, errors.New("ldap: response channel closed")) } packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", messageID, packet) + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) if err != nil { return err } @@ -227,11 +241,11 @@ func (l *Conn) StartTLS(config *tls.Config) error { return nil } -func (l *Conn) sendMessage(packet *ber.Packet) (chan *PacketResponse, error) { +func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) { return l.sendMessageWithFlags(packet, 0) } -func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (chan *PacketResponse, error) { +func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) { if l.isClosing { return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed")) } @@ -253,18 +267,25 @@ func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) l.messageMutex.Unlock() - out := make(chan *PacketResponse) + responses := make(chan *PacketResponse) + messageID := packet.Children[0].Value.(int64) message := &messagePacket{ Op: MessageRequest, - MessageID: packet.Children[0].Value.(int64), + MessageID: messageID, Packet: packet, - Channel: out, + Context: &messageContext{ + id: messageID, + done: make(chan struct{}), + responses: responses, + }, } l.sendProcessMessage(message) - return out, nil + return message.Context, nil } -func (l *Conn) finishMessage(messageID int64) { +func (l *Conn) finishMessage(msgCtx *messageContext) { + close(msgCtx.done) + if l.isClosing { return } @@ -278,7 +299,7 @@ func (l *Conn) finishMessage(messageID int64) { message := &messagePacket{ Op: MessageFinish, - MessageID: messageID, + MessageID: msgCtx.id, } l.sendProcessMessage(message) } @@ -298,15 +319,15 @@ func (l *Conn) processMessages() { if err := recover(); err != nil { log.Printf("ldap: recovered panic in processMessages: %v", err) } - for messageID, channel := range l.chanResults { + for messageID, msgCtx := range l.messageContexts { // If we are closing due to an error, inform anyone who // is waiting about the error. if l.isClosing && l.closeErr != nil { - channel <- &PacketResponse{Error: l.closeErr} + msgCtx.sendResponse(&PacketResponse{Error: l.closeErr}) } l.Debug.Printf("Closing channel for MessageID %d", messageID) - close(channel) - delete(l.chanResults, messageID) + close(msgCtx.responses) + delete(l.messageContexts, messageID) } close(l.chanMessageID) l.chanConfirm <- true @@ -335,14 +356,14 @@ func (l *Conn) processMessages() { _, err := l.conn.Write(buf) if err != nil { l.Debug.Printf("Error Sending Message: %s", err.Error()) - message.Channel <- &PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)} - close(message.Channel) + message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}) + close(message.Context.responses) break } - // Only add to chanResults if we were able to + // Only add to messageContexts if we were able to // successfully write the message. - l.chanResults[message.MessageID] = message.Channel + l.messageContexts[message.MessageID] = message.Context // Add timeout if defined if l.requestTimeout > 0 { @@ -362,8 +383,8 @@ func (l *Conn) processMessages() { } case MessageResponse: l.Debug.Printf("Receiving message %d", message.MessageID) - if chanResult, ok := l.chanResults[message.MessageID]; ok { - chanResult <- &PacketResponse{message.Packet, nil} + if msgCtx, ok := l.messageContexts[message.MessageID]; ok { + msgCtx.sendResponse(&PacketResponse{message.Packet, nil}) } else { log.Printf("Received unexpected message %d, %v", message.MessageID, l.isClosing) ber.PrintPacket(message.Packet) @@ -371,17 +392,17 @@ func (l *Conn) processMessages() { case MessageTimeout: // Handle the timeout by closing the channel // All reads will return immediately - if chanResult, ok := l.chanResults[message.MessageID]; ok { - chanResult <- &PacketResponse{message.Packet, errors.New("ldap: connection timed out")} + if msgCtx, ok := l.messageContexts[message.MessageID]; ok { l.Debug.Printf("Receiving message timeout for %d", message.MessageID) - delete(l.chanResults, message.MessageID) - close(chanResult) + msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")}) + delete(l.messageContexts, message.MessageID) + close(msgCtx.responses) } case MessageFinish: l.Debug.Printf("Finished message %d", message.MessageID) - if chanResult, ok := l.chanResults[message.MessageID]; ok { - close(chanResult) - delete(l.chanResults, message.MessageID) + if msgCtx, ok := l.messageContexts[message.MessageID]; ok { + delete(l.messageContexts, message.MessageID) + close(msgCtx.responses) } } } @@ -431,6 +452,5 @@ func (l *Conn) reader() { if !l.sendProcessMessage(message) { return } - } } diff --git a/conn_test.go b/conn_test.go index 8394e533..10766bbd 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,9 +1,14 @@ package ldap import ( + "bytes" + "errors" + "io" "net" "net/http" "net/http/httptest" + "runtime" + "sync" "testing" "time" @@ -27,19 +32,20 @@ func TestUnresponsiveConnection(t *testing.T) { defer conn.Close() // Mock a packet - messageID := conn.nextMessageID() packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, conn.nextMessageID(), "MessageID")) bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) packet.AppendChild(bindRequest) // Send packet and test response - channel, err := conn.sendMessage(packet) + msgCtx, err := conn.sendMessage(packet) if err != nil { t.Fatalf("error sending message: %v", err) } - packetResponse, ok := <-channel + defer conn.finishMessage(msgCtx) + + packetResponse, ok := <-msgCtx.responses if !ok { t.Fatalf("no PacketResponse in response channel") } @@ -51,3 +57,284 @@ func TestUnresponsiveConnection(t *testing.T) { t.Fatalf("unexpected error: %v", err) } } + +// TestFinishMessage tests that we do not enter deadlock when a goroutine makes +// a request but does not handle all responses from the server. +func TestConn(t *testing.T) { + ptc := newPacketTranslatorConn() + defer ptc.Close() + + conn := NewConn(ptc, false) + conn.Start() + + // Test sending 5 different requests in series. Ensure that we can + // get a response packet from the underlying connection and also + // ensure that we can gracefully ignore unhandled responses. + for i := 0; i < 5; i++ { + t.Logf("serial request %d", i) + // Create a message and make sure we can receive responses. + msgCtx := testSendRequest(t, ptc, conn) + testReceiveResponse(t, ptc, msgCtx) + + // Send a few unhandled responses and finish the message. + testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5) + t.Logf("serial request %d done", i) + } + + // Test sending 5 different requests in parallel. + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + t.Logf("parallel request %d", i) + // Create a message and make sure we can receive responses. + msgCtx := testSendRequest(t, ptc, conn) + testReceiveResponse(t, ptc, msgCtx) + + // Send a few unhandled responses and finish the message. + testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5) + t.Logf("parallel request %d done", i) + }(i) + } + wg.Wait() + + // We cannot run Close() in a defer because t.FailNow() will run it and + // it will block if the processMessage Loop is in a deadlock. + conn.Close() +} + +func testSendRequest(t *testing.T, ptc *packetTranslatorConn, conn *Conn) (msgCtx *messageContext) { + var msgID int64 + runWithTimeout(t, time.Second, func() { + msgID = conn.nextMessageID() + }) + + requestPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + requestPacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgID, "MessageID")) + + var err error + + runWithTimeout(t, time.Second, func() { + msgCtx, err = conn.sendMessage(requestPacket) + if err != nil { + t.Fatalf("unable to send request message: %s", err) + } + }) + + // We should now be able to get this request packet out from the other + // side. + runWithTimeout(t, time.Second, func() { + if _, err = ptc.ReceiveRequest(); err != nil { + t.Fatalf("unable to receive request packet: %s", err) + } + }) + + return msgCtx +} + +func testReceiveResponse(t *testing.T, ptc *packetTranslatorConn, msgCtx *messageContext) { + // Send a mock response packet. + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID")) + + runWithTimeout(t, time.Second, func() { + if err := ptc.SendResponse(responsePacket); err != nil { + t.Fatalf("unable to send response packet: %s", err) + } + }) + + // We should be able to receive the packet from the connection. + runWithTimeout(t, time.Second, func() { + if _, ok := <-msgCtx.responses; !ok { + t.Fatal("response channel closed") + } + }) +} + +func testSendUnhandledResponsesAndFinish(t *testing.T, ptc *packetTranslatorConn, conn *Conn, msgCtx *messageContext, numResponses int) { + // Send a mock response packet. + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID")) + + // Send extra responses but do not attempt to receive them on the + // client side. + for i := 0; i < numResponses; i++ { + runWithTimeout(t, time.Second, func() { + if err := ptc.SendResponse(responsePacket); err != nil { + t.Fatalf("unable to send response packet: %s", err) + } + }) + } + + // Finally, attempt to finish this message. + runWithTimeout(t, time.Second, func() { + conn.finishMessage(msgCtx) + }) +} + +func runWithTimeout(t *testing.T, timeout time.Duration, f func()) { + runtime.Gosched() + + done := make(chan struct{}) + go func() { + f() + close(done) + }() + + runtime.Gosched() + + select { + case <-done: // Success! + case <-time.After(timeout): + _, file, line, _ := runtime.Caller(1) + t.Fatalf("%s:%d timed out", file, line) + } +} + +// packetTranslatorConn is a helful type which can be used with various tests +// in this package. It implements the net.Conn interface to be used as an +// underlying connection for a *ldap.Conn. Most methods are no-ops but the +// Read() and Write() methods are able to translate ber-encoded packets for +// testing LDAP requests and responses. +// +// Test cases can simulate an LDAP server sending a response by calling the +// SendResponse() method with a ber-encoded LDAP response packet. Test cases +// can simulate an LDAP server receiving a request from a client by calling the +// ReceiveRequest() method which returns a ber-encoded LDAP request packet. +type packetTranslatorConn struct { + lock sync.Mutex + isClosed bool + + responseCond sync.Cond + requestCond sync.Cond + + responseBuf bytes.Buffer + requestBuf bytes.Buffer +} + +var errPacketTranslatorConnClosed = errors.New("connection closed") + +func newPacketTranslatorConn() *packetTranslatorConn { + conn := &packetTranslatorConn{} + conn.responseCond = sync.Cond{L: &conn.lock} + conn.requestCond = sync.Cond{L: &conn.lock} + + return conn +} + +// Read is called by the reader() loop to receive response packets. It will +// block until there are more packet bytes available or this connection is +// closed. +func (c *packetTranslatorConn) Read(b []byte) (n int, err error) { + c.lock.Lock() + defer c.lock.Unlock() + + for !c.isClosed { + // Attempt to read data from the response buffer. If it fails + // with an EOF, wait and try again. + n, err = c.responseBuf.Read(b) + if err != io.EOF { + return n, err + } + + c.responseCond.Wait() + } + + return 0, errPacketTranslatorConnClosed +} + +// SendResponse writes the given response packet to the response buffer for +// this conection, signalling any goroutine waiting to read a response. +func (c *packetTranslatorConn) SendResponse(packet *ber.Packet) error { + c.lock.Lock() + defer c.lock.Unlock() + + if c.isClosed { + return errPacketTranslatorConnClosed + } + + // Signal any goroutine waiting to read a response. + defer c.responseCond.Broadcast() + + // Writes to the buffer should always succeed. + c.responseBuf.Write(packet.Bytes()) + + return nil +} + +// Write is called by the processMessages() loop to send request packets. +func (c *packetTranslatorConn) Write(b []byte) (n int, err error) { + c.lock.Lock() + defer c.lock.Unlock() + + if c.isClosed { + return 0, errPacketTranslatorConnClosed + } + + // Signal any goroutine waiting to read a request. + defer c.requestCond.Broadcast() + + // Writes to the buffer should always succeed. + return c.requestBuf.Write(b) +} + +// ReceiveRequest attempts to read a request packet from this connection. It +// will block until it is able to read a full request packet or until this +// connection is closed. +func (c *packetTranslatorConn) ReceiveRequest() (*ber.Packet, error) { + c.lock.Lock() + defer c.lock.Unlock() + + for !c.isClosed { + // Attempt to parse a request packet from the request buffer. + // If it fails with an unexpected EOF, wait and try again. + requestReader := bytes.NewReader(c.requestBuf.Bytes()) + packet, err := ber.ReadPacket(requestReader) + switch err { + case io.EOF, io.ErrUnexpectedEOF: + c.requestCond.Wait() + case nil: + // Advance the request buffer by the number of bytes + // read to decode the request packet. + c.requestBuf.Next(c.requestBuf.Len() - requestReader.Len()) + return packet, nil + default: + return nil, err + } + } + + return nil, errPacketTranslatorConnClosed +} + +// Close closes this connection causing Read() and Write() calls to fail. +func (c *packetTranslatorConn) Close() error { + c.lock.Lock() + defer c.lock.Unlock() + + c.isClosed = true + c.responseCond.Broadcast() + c.requestCond.Broadcast() + + return nil +} + +func (c *packetTranslatorConn) LocalAddr() net.Addr { + return (*net.TCPAddr)(nil) +} + +func (c *packetTranslatorConn) RemoteAddr() net.Addr { + return (*net.TCPAddr)(nil) +} + +func (c *packetTranslatorConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *packetTranslatorConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *packetTranslatorConn) SetWriteDeadline(t time.Time) error { + return nil +} diff --git a/del.go b/del.go index 5bb5a25d..81cbf8e1 100644 --- a/del.go +++ b/del.go @@ -32,9 +32,8 @@ func NewDelRequest(DN string, } func (l *Conn) Del(delRequest *DelRequest) error { - messageID := l.nextMessageID() packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) packet.AppendChild(delRequest.encode()) if delRequest.Controls != nil { packet.AppendChild(encodeControls(delRequest.Controls)) @@ -42,22 +41,19 @@ func (l *Conn) Del(delRequest *DelRequest) error { l.Debug.PrintPacket(packet) - channel, err := l.sendMessage(packet) + msgCtx, err := l.sendMessage(packet) if err != nil { return err } - if channel == nil { - return NewError(ErrorNetwork, errors.New("ldap: could not send message")) - } - defer l.finishMessage(messageID) + defer l.finishMessage(msgCtx) - l.Debug.Printf("%d: waiting for response", messageID) - packetResponse, ok := <-channel + l.Debug.Printf("%d: waiting for response", msgCtx.id) + packetResponse, ok := <-msgCtx.responses if !ok { - return NewError(ErrorNetwork, errors.New("ldap: channel closed")) + return NewError(ErrorNetwork, errors.New("ldap: response channel closed")) } packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", messageID, packet) + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) if err != nil { return err } @@ -78,6 +74,6 @@ func (l *Conn) Del(delRequest *DelRequest) error { log.Printf("Unexpected Response: %d", packet.Children[1].Tag) } - l.Debug.Printf("%d: returning", messageID) + l.Debug.Printf("%d: returning", msgCtx.id) return nil } diff --git a/modify.go b/modify.go index 1c280596..7549f7dd 100644 --- a/modify.go +++ b/modify.go @@ -112,29 +112,25 @@ func NewModifyRequest( } func (l *Conn) Modify(modifyRequest *ModifyRequest) error { - messageID := l.nextMessageID() packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) packet.AppendChild(modifyRequest.encode()) l.Debug.PrintPacket(packet) - channel, err := l.sendMessage(packet) + msgCtx, err := l.sendMessage(packet) if err != nil { return err } - if channel == nil { - return NewError(ErrorNetwork, errors.New("ldap: could not send message")) - } - defer l.finishMessage(messageID) + defer l.finishMessage(msgCtx) - l.Debug.Printf("%d: waiting for response", messageID) - packetResponse, ok := <-channel + l.Debug.Printf("%d: waiting for response", msgCtx.id) + packetResponse, ok := <-msgCtx.responses if !ok { - return NewError(ErrorNetwork, errors.New("ldap: channel closed")) + return NewError(ErrorNetwork, errors.New("ldap: response channel closed")) } packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", messageID, packet) + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) if err != nil { return err } @@ -155,6 +151,6 @@ func (l *Conn) Modify(modifyRequest *ModifyRequest) error { log.Printf("Unexpected Response: %d", packet.Children[1].Tag) } - l.Debug.Printf("%d: returning", messageID) + l.Debug.Printf("%d: returning", msgCtx.id) return nil } diff --git a/passwdmodify.go b/passwdmodify.go index 6d5ca975..4a358513 100644 --- a/passwdmodify.go +++ b/passwdmodify.go @@ -73,10 +73,8 @@ func NewPasswordModifyRequest(userIdentity string, oldPassword string, newPasswo } func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*PasswordModifyResult, error) { - messageID := l.nextMessageID() - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) encodedPasswordModifyRequest, err := passwordModifyRequest.encode() if err != nil { @@ -86,24 +84,21 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa l.Debug.PrintPacket(packet) - channel, err := l.sendMessage(packet) + msgCtx, err := l.sendMessage(packet) if err != nil { return nil, err } - if channel == nil { - return nil, NewError(ErrorNetwork, errors.New("ldap: could not send message")) - } - defer l.finishMessage(messageID) + defer l.finishMessage(msgCtx) result := &PasswordModifyResult{} - l.Debug.Printf("%d: waiting for response", messageID) - packetResponse, ok := <-channel + l.Debug.Printf("%d: waiting for response", msgCtx.id) + packetResponse, ok := <-msgCtx.responses if !ok { - return nil, NewError(ErrorNetwork, errors.New("ldap: channel closed")) + return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed")) } packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", messageID, packet) + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) if err != nil { return nil, err } diff --git a/search.go b/search.go index 7e9495bd..21623ea5 100644 --- a/search.go +++ b/search.go @@ -342,9 +342,8 @@ func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) } func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { - messageID := l.nextMessageID() packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) // encode search request encodedSearchRequest, err := searchRequest.encode() if err != nil { @@ -358,14 +357,11 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { l.Debug.PrintPacket(packet) - channel, err := l.sendMessage(packet) + msgCtx, err := l.sendMessage(packet) if err != nil { return nil, err } - if channel == nil { - return nil, NewError(ErrorNetwork, errors.New("ldap: could not send message")) - } - defer l.finishMessage(messageID) + defer l.finishMessage(msgCtx) result := &SearchResult{ Entries: make([]*Entry, 0), @@ -374,13 +370,13 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { foundSearchResultDone := false for !foundSearchResultDone { - l.Debug.Printf("%d: waiting for response", messageID) - packetResponse, ok := <-channel + l.Debug.Printf("%d: waiting for response", msgCtx.id) + packetResponse, ok := <-msgCtx.responses if !ok { - return nil, NewError(ErrorNetwork, errors.New("ldap: channel closed")) + return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed")) } packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", messageID, packet) + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) if err != nil { return nil, err } @@ -421,6 +417,6 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { result.Referrals = append(result.Referrals, packet.Children[1].Children[0].Value.(string)) } } - l.Debug.Printf("%d: returning", messageID) + l.Debug.Printf("%d: returning", msgCtx.id) return result, nil }