diff --git a/server.go b/server.go index 9da6576..3c30057 100644 --- a/server.go +++ b/server.go @@ -64,6 +64,75 @@ func NewServer(opt ...Option) (*Server, error) { }, nil } +// Index of rightmost occurrence of b in s. +func last(s string, b byte) int { + i := len(s) + for i--; i >= 0; i-- { + if s[i] == b { + break + } + } + return i +} + +// validateAddr will not only validate the addr, but if it's an ipv6 literal without +// proper brackets, it will add them. +func validateAddr(addr string) (string, error) { + const op = "gldap.parseAddr" + + lastColon := last(addr, ':') + if lastColon < 0 { + return "", fmt.Errorf("%s: missing port in addr %s : %w", op, addr, ErrInvalidParameter) + } + rawHost := addr[0:lastColon] + rawPort := addr[lastColon+1:] + if len(rawPort) == 0 { + return "", fmt.Errorf("%s: missing port in addr %s : %w", op, addr, ErrInvalidParameter) + } + if len(rawHost) == 0 { + return fmt.Sprintf(":%s", rawPort), nil + } + // ipv6 literal with proper brackets + if rawHost[0] == '[' { + // Expect the first ']' just before the last ':'. + end := strings.IndexByte(rawHost, ']') + if end < 0 { + return "", fmt.Errorf("%s: missing ']' in ipv6 address %s : %w", op, addr, ErrInvalidParameter) + } + trimedIp := strings.Trim(rawHost, "[]") + if net.ParseIP(trimedIp) == nil { + return "", fmt.Errorf("%s: invalid ipv6 address %s : %w", op, rawHost, ErrInvalidParameter) + } + // ipv6 literal has enclosing brackets, and it's a valid ipv6 address, so we're good + return fmt.Sprintf("%s:%s", rawHost, rawPort), nil + } + + // see if we're dealing with a hostname + hostnames, _ := net.LookupHost(rawHost) + if len(hostnames) > 0 { + if rawHost == "::1" { + // special case for localhost + return fmt.Sprintf("[%s]:%s", rawHost, rawPort), nil + } + return fmt.Sprintf("%s:%s", rawHost, rawPort), nil + } + + lastColon = last(rawHost, ':') + if lastColon >= 0 { + // ipv6 literal without proper brackets + ipv6Literal := fmt.Sprintf("[%s]", rawHost) + if net.ParseIP(ipv6Literal) == nil { + return "", fmt.Errorf("%s: invalid ipv6 address + port %s : %w", op, addr, ErrInvalidParameter) + } + return fmt.Sprintf("[%s]:%s", ipv6Literal, rawPort), nil + } + // ipv4 + if net.ParseIP(rawHost) == nil { + return "", fmt.Errorf("%s: invalid IP address %s : %w", op, rawHost, ErrInvalidParameter) + } + return fmt.Sprintf("%s:%s", rawHost, rawPort), nil +} + // Run will run the server which will listen and serve requests. // // Options supported: WithTLSConfig @@ -72,6 +141,10 @@ func (s *Server) Run(addr string, opt ...Option) error { opts := getConfigOpts(opt...) var err error + addr, err = validateAddr(addr) + if err != nil { + return fmt.Errorf("%s: %w", op, err) + } s.mu.Lock() s.listener, err = net.Listen("tcp", addr) s.listenerReady = true diff --git a/server_internal_test.go b/server_internal_test.go index 328361f..e038247 100644 --- a/server_internal_test.go +++ b/server_internal_test.go @@ -126,3 +126,100 @@ type mockListener struct { func (*mockListener) Close() error { return errors.New("mockListener.Close error") } + +func TestValidateAddr(t *testing.T) { + tests := []struct { + name string + addr string + expected string + wantErrContains string + wantErrIs error + }{ + { + name: "valid-IPv4-address", + addr: "127.0.0.1:389", + expected: "127.0.0.1:389", + }, + { + name: "valid-IPv6-address", + addr: "[::1]:389", + expected: "[::1]:389", + }, + { + name: "valid-IPv6", + addr: "2001:db8:3333:4444:5555:6666:7777:8888:389", + expected: "2001:db8:3333:4444:5555:6666:7777:8888:389", + }, + { + name: "valid-IPv6-localhost-without-brackets", + addr: "::1:389", + expected: "[::1]:389", + }, + { + name: "valid-hostname", + addr: "localhost:389", + expected: "localhost:389", + }, + { + name: "err-missing-port-final-colon", + addr: "198.165.1.1:", + wantErrContains: "missing port in addr", + wantErrIs: ErrInvalidParameter, + }, + { + name: "missing-port -pv4", + addr: "127.0.0.1", + wantErrContains: "missing port in addr", + wantErrIs: ErrInvalidParameter, + }, + { + name: "err-missing-port-ipv6", + addr: "[::1]", + wantErrContains: "missing ']' in ipv6 address [::1]", + wantErrIs: ErrInvalidParameter, + }, + { + name: "err-invalid-IPv4-address", + addr: "0.0", + wantErrContains: "missing port in addr 0.0", + wantErrIs: ErrInvalidParameter, + }, + { + name: "err-invalid-IPv6-address-missing-bracket", + addr: "[::1", + wantErrContains: "missing ']' in ipv6 address [::1", + wantErrIs: ErrInvalidParameter, + }, + { + name: "err-invalid-IPv6", + addr: "2001:db8:3333:4444:5555:6666:7777:389", + wantErrContains: "invalid ipv6 address + port 2001:db8:3333:4444:5555:6666:7777:389", + wantErrIs: ErrInvalidParameter, + }, + { + name: "err-missing-port", + addr: "invalid", + expected: "", + wantErrContains: "missing port in addr invalid", + wantErrIs: ErrInvalidParameter, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + result, err := validateAddr(tc.addr) + if tc.wantErrContains != "" { + require.Error(t, err) + assert.Empty(t, result) + assert.Contains(t, err.Error(), tc.wantErrContains) + if tc.wantErrIs != nil { + assert.ErrorIs(t, err, tc.wantErrIs) + } + return + } + require.NoError(t, err) + assert.Equal(t, tc.expected, result) + }) + } +}