Skip to content

Commit

Permalink
feat: Server.Run(...) validate addr
Browse files Browse the repository at this point in the history
Add validation to the addr parameter of
Server.Run(...) to ensure it is a valid TCP
address.   If the addr is ipv6, and it's missing
the required square brackets, add them.
  • Loading branch information
jimlambrt committed Feb 21, 2024
1 parent be86831 commit 8f12112
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 0 deletions.
73 changes: 73 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,75 @@ func NewServer(opt ...Option) (*Server, error) {
}, nil
}

// Index of rightmost occurrence of b in s.
func last(s string, b byte) int {
i := len(s)
for i--; i >= 0; i-- {
if s[i] == b {
break
}
}
return i
}

// validateAddr will not only validate the addr, but if it's an ipv6 literal without
// proper brackets, it will add them.
func validateAddr(addr string) (string, error) {
const op = "gldap.parseAddr"

lastColon := last(addr, ':')
if lastColon < 0 {
return "", fmt.Errorf("%s: missing port in addr %s : %w", op, addr, ErrInvalidParameter)
}
rawHost := addr[0:lastColon]
rawPort := addr[lastColon+1:]
if len(rawPort) == 0 {
return "", fmt.Errorf("%s: missing port in addr %s : %w", op, addr, ErrInvalidParameter)
}
if len(rawHost) == 0 {
return fmt.Sprintf(":%s", rawPort), nil
}
// ipv6 literal with proper brackets
if rawHost[0] == '[' {
// Expect the first ']' just before the last ':'.
end := strings.IndexByte(rawHost, ']')
if end < 0 {
return "", fmt.Errorf("%s: missing ']' in ipv6 address %s : %w", op, addr, ErrInvalidParameter)
}
trimedIp := strings.Trim(rawHost, "[]")
if net.ParseIP(trimedIp) == nil {
return "", fmt.Errorf("%s: invalid ipv6 address %s : %w", op, rawHost, ErrInvalidParameter)
}
// ipv6 literal has enclosing brackets, and it's a valid ipv6 address, so we're good
return fmt.Sprintf("%s:%s", rawHost, rawPort), nil
}

// see if we're dealing with a hostname
hostnames, _ := net.LookupHost(rawHost)
if len(hostnames) > 0 {
if rawHost == "::1" {
// special case for localhost
return fmt.Sprintf("[%s]:%s", rawHost, rawPort), nil
}
return fmt.Sprintf("%s:%s", rawHost, rawPort), nil
}

lastColon = last(rawHost, ':')
if lastColon >= 0 {
// ipv6 literal without proper brackets
ipv6Literal := fmt.Sprintf("[%s]", rawHost)
if net.ParseIP(ipv6Literal) == nil {
return "", fmt.Errorf("%s: invalid ipv6 address + port %s : %w", op, addr, ErrInvalidParameter)
}
return fmt.Sprintf("[%s]:%s", ipv6Literal, rawPort), nil
}
// ipv4
if net.ParseIP(rawHost) == nil {
return "", fmt.Errorf("%s: invalid IP address %s : %w", op, rawHost, ErrInvalidParameter)
}
return fmt.Sprintf("%s:%s", rawHost, rawPort), nil
}

// Run will run the server which will listen and serve requests.
//
// Options supported: WithTLSConfig
Expand All @@ -72,6 +141,10 @@ func (s *Server) Run(addr string, opt ...Option) error {
opts := getConfigOpts(opt...)

var err error
addr, err = validateAddr(addr)
if err != nil {
return fmt.Errorf("%s: %w", op, err)
}
s.mu.Lock()
s.listener, err = net.Listen("tcp", addr)
s.listenerReady = true
Expand Down
97 changes: 97 additions & 0 deletions server_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,100 @@ type mockListener struct {
func (*mockListener) Close() error {
return errors.New("mockListener.Close error")
}

func TestValidateAddr(t *testing.T) {
tests := []struct {
name string
addr string
expected string
wantErrContains string
wantErrIs error
}{
{
name: "valid-IPv4-address",
addr: "127.0.0.1:389",
expected: "127.0.0.1:389",
},
{
name: "valid-IPv6-address",
addr: "[::1]:389",
expected: "[::1]:389",
},
{
name: "valid-IPv6",
addr: "2001:db8:3333:4444:5555:6666:7777:8888:389",
expected: "2001:db8:3333:4444:5555:6666:7777:8888:389",
},
{
name: "valid-IPv6-localhost-without-brackets",
addr: "::1:389",
expected: "[::1]:389",
},
{
name: "valid-hostname",
addr: "localhost:389",
expected: "localhost:389",
},
{
name: "err-missing-port-final-colon",
addr: "198.165.1.1:",
wantErrContains: "missing port in addr",
wantErrIs: ErrInvalidParameter,
},
{
name: "missing-port -pv4",
addr: "127.0.0.1",
wantErrContains: "missing port in addr",
wantErrIs: ErrInvalidParameter,
},
{
name: "err-missing-port-ipv6",
addr: "[::1]",
wantErrContains: "missing ']' in ipv6 address [::1]",
wantErrIs: ErrInvalidParameter,
},
{
name: "err-invalid-IPv4-address",
addr: "0.0",
wantErrContains: "missing port in addr 0.0",
wantErrIs: ErrInvalidParameter,
},
{
name: "err-invalid-IPv6-address-missing-bracket",
addr: "[::1",
wantErrContains: "missing ']' in ipv6 address [::1",
wantErrIs: ErrInvalidParameter,
},
{
name: "err-invalid-IPv6",
addr: "2001:db8:3333:4444:5555:6666:7777:389",
wantErrContains: "invalid ipv6 address + port 2001:db8:3333:4444:5555:6666:7777:389",
wantErrIs: ErrInvalidParameter,
},
{
name: "err-missing-port",
addr: "invalid",
expected: "",
wantErrContains: "missing port in addr invalid",
wantErrIs: ErrInvalidParameter,
},
}

for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
result, err := validateAddr(tc.addr)
if tc.wantErrContains != "" {
require.Error(t, err)
assert.Empty(t, result)
assert.Contains(t, err.Error(), tc.wantErrContains)
if tc.wantErrIs != nil {
assert.ErrorIs(t, err, tc.wantErrIs)
}
return
}
require.NoError(t, err)
assert.Equal(t, tc.expected, result)
})
}
}

0 comments on commit 8f12112

Please sign in to comment.