Skip to content

Commit

Permalink
Accept net.Listener instead of *net.TCPListener
Browse files Browse the repository at this point in the history
  • Loading branch information
cevatbarisyilmaz committed Sep 19, 2019
1 parent 9e379a9 commit 0e39bae
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 52 deletions.
18 changes: 2 additions & 16 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,24 +90,10 @@ func Example_advanced() {
//Give privilege to the localhost to bypass all the limits
pListener.GivePrivilege(net.IPv4(127, 0, 0, 1))
for {
//Wait for a new connection for five minutes
err = pListener.SetDeadline(time.Now().Add(5 * time.Minute))
if err != nil {
log.Fatal(err)
}
//Wait for a new connection
conn, err := pListener.Accept()
if err != nil {
//Check if the error is actually about timeout
nErr, ok := err.(net.Error)
if !ok {
log.Fatal(err)
}
if !nErr.Timeout() {
log.Fatal(nErr)
}
//Log the lack of connectivity and keep accepting new connections
log.Println("No new connections for the last 5 minute")
continue
log.Fatal(err)
}
//Handle the connection
go func() {
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module github.com/cevatbarisyilmaz/plistener

go 1.12

require golang.org/x/net v0.0.0-20190909003024-a7b16738d86b
6 changes: 3 additions & 3 deletions pconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
)

type pConn struct {
*net.TCPConn
net.Conn
listener *PListener
}

Expand All @@ -15,7 +15,7 @@ func (pConn *pConn) Close() error {
myListener.connCond.L.Lock()
myListener.currentConn--
myListener.connCond.L.Unlock()
tcpAddr := pConn.TCPConn.RemoteAddr().(*net.TCPAddr)
tcpAddr := pConn.Conn.RemoteAddr().(*net.TCPAddr)
var ip [16]byte
copy(ip[:], tcpAddr.IP.To16())
record := myListener.getRecord(ip)
Expand All @@ -34,5 +34,5 @@ func (pConn *pConn) Close() error {
}
myListener.connCond.Signal()
}
return pConn.TCPConn.Close()
return pConn.Conn.Close()
}
10 changes: 2 additions & 8 deletions pconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,8 @@ func TestPConn(t *testing.T) {
} else {
c2 = newConn
stop = func() {
err := c1.Close()
if err != nil {
t.Error(err)
}
err = c2.Close()
if err != nil {
t.Error(err)
}
_ = c1.Close()
_ = c2.Close()
err = listener.Close()
if err != nil {
t.Error(err)
Expand Down
39 changes: 19 additions & 20 deletions plistener.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ import (
// DefaultMaxConn is maximum number of connections listeners willing to keep active.
// Changing DefaultMaxConn only affects future listeners.
// To change a currently active listener use MaxConn field of PListener struct.
var DefaultMaxConn = 1048
var DefaultMaxConn = 2048

// DefaultMaxConnSingleIP is maximum number of connections listeners willing to keep active with a single IP.
// Changing DefaultMaxConnSingleIP only affects future listeners.
// To change a currently active listener use MaxConnSingleIP field of PListener struct.
var DefaultMaxConnSingleIP = 16
var DefaultMaxConnSingleIP = 24

// Limiter is a slice of pairs of durations and amounts of maximum permitted new connections for IP addresses during the stated duration.
type Limiter []struct {
Expand All @@ -32,9 +32,9 @@ type Limiter []struct {
// Changing DefaultLimiter only affects future listeners.
// To change a currently active listener use SetLimiter function of PListener struct.
var DefaultLimiter = Limiter{
{time.Second, 32},
{time.Minute, 256},
{time.Hour, 2048},
{time.Second, 64},
{time.Minute, 512},
{time.Hour, 4096},
}

type ipRecord struct {
Expand All @@ -61,7 +61,7 @@ func newIPRecord() *ipRecord {

// PListener implements the net.Listener interface with protection against spams.
type PListener struct {
*net.TCPListener
net.Listener

// MaxConn is maximum number of connections listener willing to keep active.
// Default value is DefaultMaxConn.
Expand Down Expand Up @@ -90,9 +90,9 @@ type PListener struct {
}

// New returns a new PListener that wraps the given TCPListener with anti-spam capabilities.
func New(tcpListener *net.TCPListener) (pListener *PListener) {
func New(listener net.Listener) (pListener *PListener) {
pListener = &PListener{
TCPListener: tcpListener,
Listener: listener,
MaxConn: DefaultMaxConn,
MaxConnSingleIP: DefaultMaxConnSingleIP,
OnSpam: nil,
Expand All @@ -111,13 +111,13 @@ func New(tcpListener *net.TCPListener) (pListener *PListener) {
func (pListener *PListener) Accept() (conn net.Conn, err error) {
blocked := false
banned := false
var tcpConn *net.TCPConn
var c net.Conn
var raddr *net.TCPAddr
var record *ipRecord
defer func() {
if err != nil {
if tcpConn != nil {
_ = tcpConn.Close()
if c != nil {
_ = c.Close()
}
pListener.connCond.L.Lock()
pListener.currentConn--
Expand All @@ -128,7 +128,7 @@ func (pListener *PListener) Accept() (conn net.Conn, err error) {
}()
for {
if blocked {
err = tcpConn.Close()
err = c.Close()
if err != nil {
return
}
Expand All @@ -148,12 +148,12 @@ func (pListener *PListener) Accept() (conn net.Conn, err error) {
}
pListener.currentConn++
pListener.connCond.L.Unlock()
tcpConn, err = pListener.TCPListener.AcceptTCP()
c, err = pListener.Listener.Accept()
if err != nil {
return
}
now := time.Now()
raddr = tcpConn.RemoteAddr().(*net.TCPAddr)
raddr = c.RemoteAddr().(*net.TCPAddr)
var ip [16]byte
copy(ip[:], raddr.IP.To16())
record = pListener.getRecord(ip)
Expand Down Expand Up @@ -204,15 +204,14 @@ func (pListener *PListener) Accept() (conn net.Conn, err error) {
}
record.history = append(record.history, now)
}
pconn := &pConn{TCPConn: tcpConn, listener: pListener}
pconn := &pConn{Conn: c, listener: pListener}
record.activeConns = append(record.activeConns, pconn)
conn = pconn
return
}
}

// Close closes the underlying TCP listener and erases the pointers from created connections.
// To remove internal records from the memory, remove all the pointers pointing to PListener.
// Close closes the underlying listener
func (pListener *PListener) Close() error {
pListener.ipRecordMut.Lock()
defer pListener.ipRecordMut.Unlock()
Expand All @@ -225,7 +224,7 @@ func (pListener *PListener) Close() error {
}
record.mut.Unlock()
}
return pListener.TCPListener.Close()
return pListener.Listener.Close()
}

// SetLimiter overrides the default limiter for listener.
Expand Down Expand Up @@ -254,7 +253,7 @@ func (pListener *PListener) Ban(ip net.IP) {
record.privileged = false
if record.activeConns != nil {
for _, c := range record.activeConns {
_ = c.TCPConn.Close()
_ = c.Conn.Close()
}
}
record.activeConns = nil
Expand All @@ -274,7 +273,7 @@ func (pListener *PListener) TempBan(ip net.IP, until time.Time) {
record.privileged = false
if record.activeConns != nil {
for _, c := range record.activeConns {
_ = c.TCPConn.Close()
_ = c.Conn.Close()
}
}
record.activeConns = []*pConn{}
Expand Down
10 changes: 5 additions & 5 deletions plistener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ func TestPListener_SetLimiter(t *testing.T) {
for i := 0; i < 2; i++ {
start := time.Now()
count := 0
err = listener.SetDeadline(start.Add(time.Second * 20))
err = listener.Listener.(*net.TCPListener).SetDeadline(start.Add(time.Second * 20))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -247,7 +247,7 @@ func TestPListener_Ban(t *testing.T) {
return
}
}()
err = listener.SetDeadline(time.Now().Add(time.Second * 7))
err = listener.Listener.(*net.TCPListener).SetDeadline(time.Now().Add(time.Second * 7))
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -316,7 +316,7 @@ func TestPListener_TempBan(t *testing.T) {
}
}
}()
err = listener.SetDeadline(time.Now().Add(time.Minute))
err = listener.Listener.(*net.TCPListener).SetDeadline(time.Now().Add(time.Minute))
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -374,7 +374,7 @@ func TestPListener_GivePrivilege(t *testing.T) {
}()
count := 0
for {
err = listener.SetDeadline(time.Now().Add(time.Second * 7))
err = listener.Listener.(*net.TCPListener).SetDeadline(time.Now().Add(time.Second * 7))
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -456,7 +456,7 @@ func TestPListener_OnSpam(t *testing.T) {
banned = true
}
for {
err = listener.SetDeadline(time.Now().Add(time.Second * 5))
err = listener.Listener.(*net.TCPListener).SetDeadline(time.Now().Add(time.Second * 5))
if err != nil {
log.Fatal(err)
}
Expand Down

0 comments on commit 0e39bae

Please sign in to comment.