Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Fixed message processing deadlocks and added mutex for closing function
  • Loading branch information
mmitton committed Feb 18, 2011
1 parent ca949cd commit 611f66a
Showing 1 changed file with 51 additions and 20 deletions.
71 changes: 51 additions & 20 deletions conn.go
Expand Up @@ -11,6 +11,7 @@ import (
"fmt"
"net"
"os"
"sync"
)

// LDAP Connection
Expand All @@ -22,6 +23,8 @@ type Conn struct {
chanResults map[ uint64 ] chan *ber.Packet
chanProcessMessage chan *messagePacket
chanMessageID chan uint64

closeLock sync.Mutex
}

// Dial connects to the given address on the given network using net.Dial
Expand Down Expand Up @@ -87,11 +90,10 @@ func (l *Conn) start() {

// Close closes the connection.
func (l *Conn) Close() *Error {
if l.chanProcessMessage != nil {
message_packet := &messagePacket{ Op: MessageQuit }
l.chanProcessMessage <- message_packet
l.chanProcessMessage = nil
}
l.closeLock.Lock()
defer l.closeLock.Unlock()

l.sendProcessMessage( &messagePacket{ Op: MessageQuit } )

if l.conn != nil {
err := l.conn.Close()
Expand All @@ -104,9 +106,10 @@ func (l *Conn) Close() *Error {
}

// Returns the next available messageID
func (l *Conn) nextMessageID() uint64 {
messageID := <-l.chanMessageID
return messageID
func (l *Conn) nextMessageID() (messageID uint64) {
defer func() { if r := recover(); r != nil { messageID = 0 } }()
messageID = <-l.chanMessageID
return
}

// StartTLS sends the command to start a TLS session and then creates a new TLS Client
Expand Down Expand Up @@ -170,12 +173,12 @@ func (l *Conn) sendMessage( p *ber.Packet ) (out chan *ber.Packet, err *Error) {
message_id := p.Children[ 0 ].Value.(uint64)
out = make(chan *ber.Packet)

message_packet := &messagePacket{ Op: MessageRequest, MessageID: message_id, Packet: p, Channel: out }
if l.chanProcessMessage == nil {
err = NewError( ErrorNetwork, os.NewError( "Connection closed" ) )
return
}
l.chanProcessMessage <- message_packet
message_packet := &messagePacket{ Op: MessageRequest, MessageID: message_id, Packet: p, Channel: out }
l.sendProcessMessage( message_packet )
return
}

Expand Down Expand Up @@ -208,17 +211,32 @@ func (l *Conn) processMessages() {
fmt.Printf( "Sending message %d\n", message_packet.MessageID )
}
l.chanResults[ message_packet.MessageID ] = message_packet.Channel
l.conn.Write( message_packet.Packet.Bytes() )
buf := message_packet.Packet.Bytes()
for len( buf ) > 0 {
n, err := l.conn.Write( buf )
if err != nil {
if l.Debug {
fmt.Printf( "Error Sending Message: %s\n", err.String() )
}
return
}
if n == len( buf ) {
break
}
buf = buf[n:]
}
case MessageResponse:
// Pass back to waiting goroutine
if l.Debug {
fmt.Printf( "Receiving message %d\n", message_packet.MessageID )
}
chanResult := l.chanResults[ message_packet.MessageID ]
if chanResult == nil {
fmt.Printf( "Unexpected Message Result: %d", message_id )
fmt.Printf( "Unexpected Message Result: %d\n", message_id )
ber.PrintPacket( message_packet.Packet )
} else {
chanResult <- message_packet.Packet
go func() { chanResult <- message_packet.Packet }()
// chanResult <- message_packet.Packet
}
case MessageFinish:
// Remove from message list
Expand All @@ -232,6 +250,7 @@ func (l *Conn) processMessages() {
}

func (l *Conn) closeAllChannels() {
fmt.Printf( "closeAllChannels\n" )
for MessageID, Channel := range l.chanResults {
if l.Debug {
fmt.Printf( "Closing channel for MessageID %d\n", MessageID );
Expand All @@ -241,30 +260,42 @@ func (l *Conn) closeAllChannels() {
}
close( l.chanMessageID )
l.chanMessageID = nil

close( l.chanProcessMessage )
l.chanProcessMessage = nil
}

func (l *Conn) finishMessage( MessageID uint64 ) {
message_packet := &messagePacket{ Op: MessageFinish, MessageID: MessageID }
if l.chanProcessMessage != nil {
l.chanProcessMessage <- message_packet
}
l.sendProcessMessage( message_packet )
}

func (l *Conn) reader() {
defer l.Close()
for {
p, err := ber.ReadPacket( l.conn )
if err != nil {
if l.Debug {
fmt.Printf( "ldap.reader: %s\n", err.String() )
}
break
return
}

addLDAPDescriptions( p )

message_id := p.Children[ 0 ].Value.(uint64)
message_packet := &messagePacket{ Op: MessageResponse, MessageID: message_id, Packet: p }
l.chanProcessMessage <- message_packet
if l.chanProcessMessage != nil {
l.chanProcessMessage <- message_packet
} else {
fmt.Printf( "ldap.reader: Cannot return message\n" )
return
}
}

l.Close()
}

func (l *Conn) sendProcessMessage( message *messagePacket ) {
if l.chanProcessMessage != nil {
go func() { l.chanProcessMessage <- message }()
}
}

0 comments on commit 611f66a

Please sign in to comment.