Skip to content

Commit

Permalink
Add fallback-able support
Browse files Browse the repository at this point in the history
  • Loading branch information
blacktear23 committed Feb 14, 2023
1 parent 7a56c92 commit d98ad85
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 39 deletions.
19 changes: 17 additions & 2 deletions README.md
Expand Up @@ -21,7 +21,7 @@ basic usage
l, err := net.Listen("tcp", "...")

// Wrap listener as PROXY protocol listener
ppl, err := proxyprotocol.NewListener(l, "*", 5)
ppl, err := proxyprotocol.NewListener(l, "*", 5, false)

for {
conn, err := ppl.Accept()
Expand Down Expand Up @@ -56,7 +56,22 @@ l, err := net.Listener("tcp", "...")


// Wrap listener as PROXY protocol listener and enable lazy mode.
ppl, err := proxyprotocol.NewLazyListener(l, "*", 5)
ppl, err := proxyprotocol.NewLazyListener(l, "*", 5, false)

...
```

## Fallback-able

`go-proxyprotocol` support fallback-able mode for ProxyProtocol header process. When multiple client with different system connect to the server and some using PROXY protocol some not and it's hard to determine the allowed IP range, just set `fallbackable` parameter to `true`, it can handle this.

```go
// Create listener
l, err := net.Listener("tcp", "...")


// Wrap listener as PROXY protocol listener and enable lazy mode and fallback-able
ppl, err := proxyprotocol.NewLazyListener(l, "*", 5, true)

...
```
67 changes: 42 additions & 25 deletions proxy_protocol.go
Expand Up @@ -28,8 +28,10 @@ const (
var (
ErrProxyProtocolV1HeaderInvalid = errors.New("PROXY Protocol v1 header is invalid")
ErrProxyProtocolV2HeaderInvalid = errors.New("PROXY Protocol v2 header is invalid")
ErrProxyProtocolInvalid = errors.New("Invalid PROXY Protocol Header")
ErrHeaderReadTimeout = errors.New("Header read timeout")
proxyProtocolV2Sig = []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}
proxyProtocolV1Sig = []byte("PROXY")

_ net.Conn = &proxyProtocolConn{}
_ net.Listener = &proxyProtocolListener{}
Expand All @@ -42,26 +44,28 @@ type connErr struct {

type proxyProtocolListener struct {
listener net.Listener
allowAll bool
lazyMode bool
allowedNets []*net.IPNet
runningFlag int32
headerReadTimeout int // Unit is second
acceptQueue chan *connErr
runningFlag int32
allowAll bool
lazyMode bool
fallbackable bool
}

func IsProxyProtocolError(err error) bool {
return err == ErrProxyProtocolV1HeaderInvalid ||
err == ErrProxyProtocolV2HeaderInvalid ||
err == ErrHeaderReadTimeout
err == ErrHeaderReadTimeout ||
err == ErrProxyProtocolInvalid
}

// Create new PROXY protocol listener
// * listener is basic listener for TCP
// * allowedIPs is protocol allowed addresses or CIDRs split by `,` if use '*' means allow any address
// * headerReadTimeout is timeout for PROXY protocol header read
func NewListener(listener net.Listener, allowedIPs string, headerReadTimeout int) (net.Listener, error) {
ppl, err := newListener(listener, allowedIPs, headerReadTimeout, false)
func NewListener(listener net.Listener, allowedIPs string, headerReadTimeout int, fallbackable bool) (net.Listener, error) {
ppl, err := newListener(listener, allowedIPs, headerReadTimeout, false, fallbackable)
if err == nil {
go ppl.acceptLoop()
}
Expand All @@ -73,15 +77,15 @@ func NewListener(listener net.Listener, allowedIPs string, headerReadTimeout int
// * listener is basic listener for TCP
// * allowedIPs is protocol allowed addresses or CIDRs split by `,` if use '*' means allow any address
// * headerReadTimeout is timeout for PROXY protocol header read
func NewLazyListener(listener net.Listener, allowedIPs string, headerReadTimeout int) (net.Listener, error) {
ppl, err := newListener(listener, allowedIPs, headerReadTimeout, true)
func NewLazyListener(listener net.Listener, allowedIPs string, headerReadTimeout int, fallbackable bool) (net.Listener, error) {
ppl, err := newListener(listener, allowedIPs, headerReadTimeout, true, fallbackable)
if err == nil {
go ppl.acceptLoop()
}
return ppl, err
}

func newListener(listener net.Listener, allowedIPs string, headerReadTimeout int, lazyMode bool) (*proxyProtocolListener, error) {
func newListener(listener net.Listener, allowedIPs string, headerReadTimeout int, lazyMode bool, fallbackable bool) (*proxyProtocolListener, error) {
allowAll := false
allowedNets := []*net.IPNet{}
if allowedIPs == "*" {
Expand Down Expand Up @@ -109,6 +113,7 @@ func newListener(listener net.Listener, allowedIPs string, headerReadTimeout int
acceptQueue: make(chan *connErr, 1),
runningFlag: 1,
lazyMode: lazyMode,
fallbackable: fallbackable,
}, nil
}

Expand Down Expand Up @@ -137,6 +142,7 @@ func (l *proxyProtocolListener) createProxyProtocolConn(conn net.Conn) (*proxyPr
headerReadTimeout: l.headerReadTimeout,
lazyMode: l.lazyMode,
headerReaded: false,
fallbackable: l.fallbackable,
}
if !l.lazyMode {
err := ppconn.readClientAddrBehindProxy(conn.RemoteAddr())
Expand Down Expand Up @@ -215,6 +221,7 @@ type proxyProtocolConn struct {
exceedBufferReaded bool
lazyMode bool
headerReaded bool
fallbackable bool
}

func (c *proxyProtocolConn) readClientAddrBehindProxy(connRemoteAddr net.Addr) error {
Expand Down Expand Up @@ -243,6 +250,14 @@ func (c *proxyProtocolConn) parseHeader(connRemoteAddr net.Addr) error {
c.clientIP = raddr
c.headerReaded = true
return nil
case unknownProtocol:
if c.fallbackable {
c.exceedBuffer = buffer
c.exceedBufferLen = len(buffer)
c.headerReaded = true
return nil
}
return ErrProxyProtocolInvalid
default:
panic("Should not come here")
}
Expand Down Expand Up @@ -374,6 +389,7 @@ func (c *proxyProtocolConn) readHeader() (int, []byte, error) {
return unknownProtocol, nil, ErrHeaderReadTimeout
}
if n >= 16 {
// Chech Proxy Protocol V2 header
if bytes.Equal(buf[0:12], proxyProtocolV2Sig) && (buf[v2CmdPos]&0xF0) == 0x20 {
endPos := 16 + int(binary.BigEndian.Uint16(buf[v2LenPos:v2LenPos+2]))
if n < endPos {
Expand All @@ -400,22 +416,23 @@ func (c *proxyProtocolConn) readHeader() (int, []byte, error) {
}
}
if n >= 5 {
if string(buf[0:5]) != "PROXY" {
return unknownProtocol, nil, ErrProxyProtocolV1HeaderInvalid
}
pos := bytes.IndexByte(buf, byte(10))
if pos == -1 {
return unknownProtocol, nil, ErrProxyProtocolV1HeaderInvalid
}
if buf[pos-1] != byte(13) {
return unknownProtocol, nil, ErrProxyProtocolV1HeaderInvalid
}
endPos := pos
if n > endPos {
c.exceedBuffer = buf[endPos+1:]
c.exceedBufferLen = n - endPos
// Chech Proxy Protocol V1 header
if bytes.Equal(buf[0:5], proxyProtocolV1Sig) {
pos := bytes.IndexByte(buf, byte(10))
if pos == -1 {
return unknownProtocol, nil, ErrProxyProtocolV1HeaderInvalid
}
if buf[pos-1] != byte(13) {
return unknownProtocol, nil, ErrProxyProtocolV1HeaderInvalid
}
endPos := pos
if n > endPos {
c.exceedBuffer = buf[endPos+1:]
c.exceedBufferLen = n - endPos
}
return proxyProtocolV1, buf[0 : endPos+1], nil
}
return proxyProtocolV1, buf[0 : endPos+1], nil
}
return unknownProtocol, nil, ErrProxyProtocolV1HeaderInvalid
// Unknown protocol
return unknownProtocol, buf[0:n], nil
}
86 changes: 74 additions & 12 deletions proxy_protocol_test.go
Expand Up @@ -89,10 +89,10 @@ func assertEquals[T comparable](t *testing.T, val, expected T, comments ...any)
}

func TestProxyProtocolConnCheckAllowed(t *testing.T) {
l, _ := newListener(nil, "*", 5, false)
l, _ := newListener(nil, "*", 5, false, false)
raddr, _ := net.ResolveTCPAddr("tcp4", "192.168.1.100:8080")
assertTrue(t, l.checkAllowed(raddr))
l, _ = newListener(nil, "192.168.1.0/24,192.168.2.0/24", 5, false)
l, _ = newListener(nil, "192.168.1.0/24,192.168.2.0/24", 5, false, false)
for _, ipstr := range []string{"192.168.1.100:8080", "192.168.2.100:8080"} {
raddr, _ := net.ResolveTCPAddr("tcp4", ipstr)
assertTrue(t, l.checkAllowed(raddr))
Expand All @@ -107,7 +107,7 @@ func TestProxyProtocolConnMustNotReadAnyDataAfterCLRF(t *testing.T) {
buffer := []byte("PROXY TCP4 192.168.1.100 192.168.1.50 5678 3306\r\nOther Data")
conn := newMockBufferConn(bytes.NewBuffer(buffer), nil)

l, _ := newListener(nil, "*", 5, false)
l, _ := newListener(nil, "*", 5, false, false)
wconn, err := l.createProxyProtocolConn(conn)
assertNil(t, err)

Expand Down Expand Up @@ -147,7 +147,7 @@ func TestProxyProtocolV2ConnMustNotReadAnyDataAfterHeader(t *testing.T) {
buffer := encodeProxyProtocolV2Header("tcp4", "192.168.1.100:5678", "192.168.1.5:4000")
expectedString := "Other Data"
buffer = append(buffer, []byte(expectedString)...)
l, _ := newListener(nil, "*", 5, false)
l, _ := newListener(nil, "*", 5, false, false)
conn := newMockBufferConn(bytes.NewBuffer(buffer), craddr)
wconn, err := l.createProxyProtocolConn(conn)
buf := make([]byte, len(expectedString))
Expand Down Expand Up @@ -182,7 +182,7 @@ func TestProxyProtocolV2ConnMustNotReadAnyDataAfterHeaderAndTlvs(t *testing.T) {
for _, test := range tests {
buffer := test.buffer
buffer = append(buffer, []byte(test.expect)...)
l, _ := newListener(nil, "*", 5, false)
l, _ := newListener(nil, "*", 5, false, false)
conn := newMockBufferConn(bytes.NewBuffer(buffer), craddr)
wconn, err := l.createProxyProtocolConn(conn)
buf := make([]byte, len(test.expect))
Expand Down Expand Up @@ -265,7 +265,7 @@ func TestProxyProtocolV1ExtractClientIP(t *testing.T) {
},
}

l, _ := newListener(nil, "*", 5, false)
l, _ := newListener(nil, "*", 5, false, false)
for _, test := range tests {
conn := newMockBufferConn(bytes.NewBuffer(test.buffer), craddr)
wconn, err := l.createProxyProtocolConn(conn)
Expand Down Expand Up @@ -379,7 +379,7 @@ func TestProxyProtocolV2HeaderRead(t *testing.T) {
},
}

l, _ := newListener(nil, "*", 5, false)
l, _ := newListener(nil, "*", 5, false, false)
for _, test := range tests {
conn := newMockBufferConn(bytes.NewBuffer(test.buffer), craddr)
wconn, err := l.createProxyProtocolConn(conn)
Expand All @@ -400,7 +400,7 @@ func TestProxyProtocolV2HeaderReadLocalCommand(t *testing.T) {
craddr, _ := net.ResolveTCPAddr("tcp4", "192.168.1.51:8080")
buffer := encodeProxyProtocolV2Header("tcp4", "192.168.1.100:5678", "192.168.1.5:4000")
buffer[v2CmdPos] = 0x20
l, _ := newListener(nil, "*", 5, false)
l, _ := newListener(nil, "*", 5, false, false)
conn := newMockBufferConn(bytes.NewBuffer(buffer), craddr)
wconn, err := l.createProxyProtocolConn(conn)
clientIP := wconn.RemoteAddr()
Expand All @@ -415,7 +415,7 @@ func TestProxyProtocolListenerReadHeaderTimeout(t *testing.T) {
go func() {
l, err := net.Listen("tcp", addr)
assertNil(t, err)
ppl, err := NewListener(l, "*", 1)
ppl, err := NewListener(l, "*", 1, false)
assertNil(t, err)
defer ppl.Close()
wg.Done()
Expand All @@ -438,7 +438,7 @@ func TestProxyProtocolListenerProxyNotAllowed(t *testing.T) {
go func() {
l, err := net.Listen("tcp", addr)
assertNil(t, err)
ppl, err := NewListener(l, "192.168.1.1", 1)
ppl, err := NewListener(l, "192.168.1.1", 1, false)
assertNil(t, err)
defer ppl.Close()
wg.Done()
Expand All @@ -459,7 +459,7 @@ func TestProxyProtocolListenerCloseInOtherGoroutine(t *testing.T) {
addr := "127.0.0.1:18082"
l, err := net.Listen("tcp", addr)
assertNil(t, err)
ppl, err := NewListener(l, "*", 1)
ppl, err := NewListener(l, "*", 1, false)
assertNil(t, err)
go func() {
conn, err := ppl.Accept()
Expand Down Expand Up @@ -516,7 +516,7 @@ func TestProxyProtocolLazyMode(t *testing.T) {
expectErr: true,
},
}
l, _ := newListener(nil, "*", 5, true)
l, _ := newListener(nil, "*", 5, true, false)
for _, test := range tests {
buffer := test.buffer
buffer = append(buffer, []byte(test.expectData)...)
Expand All @@ -539,3 +539,65 @@ func TestProxyProtocolLazyMode(t *testing.T) {
}
}
}

func TestProxyProtocolLazyModeFallback(t *testing.T) {
tlvData1 := append([]byte{0xE3, 0x00, 0x01}, make([]byte, 100)...)
craddr, _ := net.ResolveTCPAddr("tcp4", "192.168.1.51:8080")
tests := []struct {
buffer []byte
expectData string
expectIP string
expectErr bool
}{
{
buffer: []byte("Raw Connection Other Data"),
expectData: "Raw Connection Other Data",
expectIP: "192.168.1.51:8080",
expectErr: false,
},
{
buffer: append(encodeProxyProtocolV2HeaderAndTlv("tcp4", "192.168.1.100:5678", "192.168.1.5:4000", tlvData1), []byte("Other Data")...),
expectData: "Other Data",
expectIP: "192.168.1.100:5678",
expectErr: false,
},
{
buffer: append(encodeProxyProtocolV2Header("tcp4", "192.168.1.100:5678", "192.168.1.5:4000"), []byte("Other Data")...),
expectData: "Other Data",
expectIP: "192.168.1.100:5678",
expectErr: false,
},
{
buffer: []byte("PROXY MCP3 192.168.1.100 192.168.1.50 5678 3306\r\nOther Data"),
expectData: "Other Data",
expectIP: "",
expectErr: true,
},
{
buffer: []byte("Some bad data"),
expectData: "Some bad data",
expectIP: "192.168.1.51:8080",
expectErr: false,
},
}
l, _ := newListener(nil, "*", 5, true, true)
for _, test := range tests {
buffer := test.buffer
conn := newMockBufferConn(bytes.NewBuffer(buffer), craddr)
wconn, err := l.createProxyProtocolConn(conn)
clientIP := wconn.RemoteAddr()
assertEquals(t, clientIP.String(), craddr.String(), "Buffer:%s\nExpect: %s Got: %s", string(buffer), craddr.String(), clientIP.String())
buf := make([]byte, len(test.expectData))
n, err := wconn.Read(buf)
if test.expectErr {
if err == nil {
t.Errorf("Buffer: %s\nExpect Error", string(buffer))
}
} else {
assertNil(t, err)
assertEquals(t, string(buf[0:n]), test.expectData)
clientIP = wconn.RemoteAddr()
assertEquals(t, clientIP.String(), test.expectIP, "Buffer:%s\nExpect: %s Got: %s", string(buffer), test.expectIP, clientIP.String())
}
}
}

0 comments on commit d98ad85

Please sign in to comment.