diff --git a/channel_test.go b/channel_test.go index 642a456..7d5522f 100644 --- a/channel_test.go +++ b/channel_test.go @@ -40,7 +40,7 @@ func TestPool_Get_Impl(t *testing.T) { t.Errorf("Get error: %s", err) } - _, ok := conn.(poolConn) + _, ok := conn.(*PoolConn) if !ok { t.Errorf("Conn is not of type poolConn") } @@ -119,6 +119,33 @@ func TestPool_Put(t *testing.T) { } } +func TestPool_PutUnusableConn(t *testing.T) { + p, _ := newChannelPool() + defer p.Close() + + // ensure pool is not empty + conn, _ := p.Get() + conn.Close() + + poolSize := p.Len() + conn, _ = p.Get() + conn.Close() + if p.Len() != poolSize { + t.Errorf("Pool size is expected to be equal to initial size") + } + + conn, _ = p.Get() + if pc, ok := conn.(*PoolConn); !ok { + t.Errorf("impossible") + } else { + pc.MarkUnusable() + } + conn.Close() + if p.Len() != poolSize-1 { + t.Errorf("Pool size is expected to be initial_size - 1", p.Len(), poolSize-1) + } +} + func TestPool_UsedCapacity(t *testing.T) { p, _ := newChannelPool() defer p.Close() diff --git a/conn.go b/conn.go index b87ac71..bc2d705 100644 --- a/conn.go +++ b/conn.go @@ -2,21 +2,33 @@ package pool import "net" -// poolConn is a wrapper around net.Conn to modify the the behavior of +// PoolConn is a wrapper around net.Conn to modify the the behavior of // net.Conn's Close() method. -type poolConn struct { +type PoolConn struct { net.Conn - c *channelPool + c *channelPool + unusable bool } // Close() puts the given connects back to the pool instead of closing it. -func (p poolConn) Close() error { +func (p PoolConn) Close() error { + if p.unusable { + if p.Conn != nil { + return p.Conn.Close() + } + return nil + } return p.c.put(p.Conn) } +// MarkUnusable() marks the connection not usable any more, to let the pool close it instead of returning it to pool. +func (p *PoolConn) MarkUnusable() { + p.unusable = true +} + // newConn wraps a standard net.Conn to a poolConn net.Conn. func (c *channelPool) wrapConn(conn net.Conn) net.Conn { - p := poolConn{c: c} + p := &PoolConn{c: c} p.Conn = conn return p } diff --git a/conn_test.go b/conn_test.go index 6d287c6..55f9237 100644 --- a/conn_test.go +++ b/conn_test.go @@ -6,5 +6,5 @@ import ( ) func TestConn_Impl(t *testing.T) { - var _ net.Conn = new(poolConn) + var _ net.Conn = new(PoolConn) }