Skip to content

Commit

Permalink
fix: fix how validateAddrPort(...) handles ipv6 literals
Browse files Browse the repository at this point in the history
  • Loading branch information
jimlambrt committed Mar 6, 2024
1 parent bcbc309 commit 9fa194c
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 30 deletions.
43 changes: 23 additions & 20 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"crypto/tls"
"fmt"
"net"
"net/netip"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -75,35 +76,37 @@ func last(s string, b byte) int {
return i
}

// validateAddr will not only validate the addr, but if it's an ipv6 literal without
// proper brackets, it will add them.
func validateAddr(addr string) (string, error) {
// validateAddrPort will not only validate the address+port, but if it's an ipv6
// literal without proper brackets, it will add them.
func validateAddrPort(addrPort string) (string, error) {
const op = "gldap.parseAddr"

lastColon := last(addr, ':')
lastColon := last(addrPort, ':')
if lastColon < 0 {
return "", fmt.Errorf("%s: missing port in addr %s : %w", op, addr, ErrInvalidParameter)
return "", fmt.Errorf("%s: missing port in addr \"%s\": %w", op, addrPort, ErrInvalidParameter)
}
rawHost := addr[0:lastColon]
rawPort := addr[lastColon+1:]
rawHost := addrPort[0:lastColon]
rawPort := addrPort[lastColon+1:]
switch {
case len(rawPort) == 0:
return "", fmt.Errorf("%s: missing port in addr %s : %w", op, addr, ErrInvalidParameter)
return "", fmt.Errorf("%s: missing port in addr \"%s\": %w", op, addrPort, ErrInvalidParameter)
case len(rawHost) == 0:
return fmt.Sprintf(":%s", rawPort), nil
case addr[0] == '[' && addr[len(addr)-1] == ']':
return "", fmt.Errorf("%s: missing port in ipv6 addr : %s : %w", op, addr, ErrInvalidParameter)
case addrPort[0] == '[' && addrPort[len(addrPort)-1] == ']':
return "", fmt.Errorf("%s: missing port in ipv6 addr : \"%s\": %w", op, addrPort, ErrInvalidParameter)
}
// ipv6 literal with proper brackets
if rawHost[0] == '[' {
// Expect the first ']' just before the last ':'.
end := strings.IndexByte(rawHost, ']')
if end < 0 {
return "", fmt.Errorf("%s: missing ']' in ipv6 address %s : %w", op, addr, ErrInvalidParameter)
return "", fmt.Errorf("%s: missing ']' in ipv6 address \"%s\": %w", op, addrPort, ErrInvalidParameter)
}
// Note: netip.ParseAddr requires ipv6 addresses without brackets []
trimmedIp := strings.Trim(rawHost, "[]")
if net.ParseIP(trimmedIp) == nil {
return "", fmt.Errorf("%s: invalid ipv6 address %s : %w", op, rawHost, ErrInvalidParameter)
if _, err := netip.ParseAddr(trimmedIp); err != nil {
// if net.ParseIP(trimmedIp) == nil {
return "", fmt.Errorf("%s: invalid ipv6 address \"%s\": %w", op, rawHost, err)
}
// ipv6 literal has enclosing brackets, and it's a valid ipv6 address, so we're good
return fmt.Sprintf("%s:%s", rawHost, rawPort), nil
Expand All @@ -123,16 +126,16 @@ func validateAddr(addr string) (string, error) {

lastColon = last(rawHost, ':')
if lastColon >= 0 {
// ipv6 literal without proper brackets
ipv6Literal := fmt.Sprintf("[%s]", rawHost)
if net.ParseIP(ipv6Literal) == nil {
return "", fmt.Errorf("%s: invalid ipv6 address + port %s : %w", op, addr, ErrInvalidParameter)
// ipv6 literal without proper brackets. Note: netip.ParseAddr requires
// ipv6 addresses without brackets []
if _, err := netip.ParseAddr(rawHost); err != nil {
return "", fmt.Errorf("%s: invalid ipv6 address + port \"%s\": %w", op, addrPort, err)
}
return fmt.Sprintf("[%s]:%s", ipv6Literal, rawPort), nil
return fmt.Sprintf("[%s]:%s", rawHost, rawPort), nil
}
// ipv4
if net.ParseIP(rawHost) == nil {
return "", fmt.Errorf("%s: invalid IP address %s : %w", op, rawHost, ErrInvalidParameter)
return "", fmt.Errorf("%s: invalid IP address \"%s\": %w", op, rawHost, ErrInvalidParameter)
}
return fmt.Sprintf("%s:%s", rawHost, rawPort), nil
}
Expand All @@ -145,7 +148,7 @@ func (s *Server) Run(addr string, opt ...Option) error {
opts := getConfigOpts(opt...)

var err error
addr, err = validateAddr(addr)
addr, err = validateAddrPort(addr)
if err != nil {
return fmt.Errorf("%s: %w", op, err)
}
Expand Down
18 changes: 8 additions & 10 deletions server_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (*mockListener) Close() error {
return errors.New("mockListener.Close error")
}

func TestValidateAddr(t *testing.T) {
func Test_validateAddrPort(t *testing.T) {
tests := []struct {
name string
addr string
Expand Down Expand Up @@ -180,46 +180,44 @@ func TestValidateAddr(t *testing.T) {
{
name: "err-missing-port-ipv6",
addr: "[::1]",
wantErrContains: "missing port in ipv6 addr : [::1]",
wantErrContains: "missing port in ipv6 addr : \"[::1]\"",
wantErrIs: ErrInvalidParameter,
},
{
name: "err-invalid-IPv4-address",
addr: "0.0",
wantErrContains: "missing port in addr 0.0",
wantErrContains: "missing port in addr \"0.0\"",
wantErrIs: ErrInvalidParameter,
},
{
name: "err-invalid-IPv6-address-missing-start-bracket",
addr: "::1]",
wantErrContains: "invalid ipv6 address + port ::1]",
wantErrIs: ErrInvalidParameter,
wantErrContains: "invalid ipv6 address + port \"::1]\": ParseAddr(\":\"): each colon-separated field must have at least one digit (at \":\")",
},
{
name: "err-invalid-IPv6-address-missing-final-bracket",
addr: "[::1",
wantErrContains: "missing ']' in ipv6 address [::1",
wantErrContains: "missing ']' in ipv6 address \"[::1\"",
wantErrIs: ErrInvalidParameter,
},
{
name: "err-invalid-IPv6",
addr: "2001:db8:3333:4444:5555:6666:7777:389",
wantErrContains: "invalid ipv6 address + port 2001:db8:3333:4444:5555:6666:7777:389",
wantErrIs: ErrInvalidParameter,
wantErrContains: "invalid ipv6 address + port \"2001:db8:3333:4444:5555:6666:7777:389\": ParseAddr(\"2001:db8:3333:4444:5555:6666:7777\"): address string too short",
},
{
name: "err-missing-port",
addr: "invalid",
expected: "",
wantErrContains: "missing port in addr invalid",
wantErrContains: "missing port in addr \"invalid\"",
wantErrIs: ErrInvalidParameter,
},
}

for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
result, err := validateAddr(tc.addr)
result, err := validateAddrPort(tc.addr)
if tc.wantErrContains != "" {
require.Error(t, err)
assert.Empty(t, result)
Expand Down

0 comments on commit 9fa194c

Please sign in to comment.