Skip to content

Commit

Permalink
netstack: refactor IPv4 source address selection
Browse files Browse the repository at this point in the history
This was in the wrong place. This CL just refactors it into package ipv4, just
as IPv6-specific logic is in package ipv6.

PiperOrigin-RevId: 558922801
  • Loading branch information
kevinGC authored and gvisor-bot committed Aug 21, 2023
1 parent 848ec33 commit baf097a
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 29 deletions.
18 changes: 17 additions & 1 deletion pkg/tcpip/network/ipv4/ipv4.go
Expand Up @@ -1424,7 +1424,23 @@ func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allow
//
// +checklocksread:e.mu
func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
return e.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
if remoteAddr.BitLen() == 0 {
return e.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
}

var best stack.AddressEndpoint
var bestLen uint8
e.addressableEndpointState.ForEachPrimaryEndpoint(func(ep stack.AddressEndpoint) bool {
if matchLen := ep.AddressWithPrefix().Address.MatchingPrefix(remoteAddr); best == nil || bestLen < matchLen {
best = ep
bestLen = matchLen
}
return true
})
if best != nil {
best.IncRef()
}
return best
}

// PrimaryAddresses implements stack.AddressableEndpoint.
Expand Down
30 changes: 3 additions & 27 deletions pkg/tcpip/stack/addressable_endpoint_state.go
Expand Up @@ -18,7 +18,6 @@ import (
"fmt"

"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)

func (lifetimes *AddressLifetimes) sanitize() {
Expand Down Expand Up @@ -434,7 +433,7 @@ func (a *AddressableEndpointState) MainAddress() tcpip.AddressWithPrefix {
a.mu.RLock()
defer a.mu.RUnlock()

ep := a.acquirePrimaryAddressRLocked(tcpip.Address{}, func(ep *addressState) bool {
ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool {
switch kind := ep.GetKind(); kind {
case Permanent:
return a.networkEndpoint.Enabled() || !a.options.HiddenWhileDisabled
Expand Down Expand Up @@ -462,30 +461,7 @@ func (a *AddressableEndpointState) MainAddress() tcpip.AddressWithPrefix {
// valid according to isValid.
//
// +checklocksread:a.mu
func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(remoteAddr tcpip.Address, isValid func(*addressState) bool) *addressState {
// TODO: Move this out into IPv4-specific code.
// IPv6 handles source IP selection elsewhere. We have to do source
// selection only for IPv4, in which case ep is never deprecated. Thus
// we don't have to worry about refcounts.
if remoteAddr.Len() == header.IPv4AddressSize && remoteAddr != (tcpip.Address{}) {
var best *addressState
var bestLen uint8
for _, state := range a.primary {
if !isValid(state) {
continue
}
stateLen := state.addr.Address.MatchingPrefix(remoteAddr)
if best == nil || bestLen < stateLen {
best = state
bestLen = stateLen
}
}
if best != nil {
best.IncRef()
}
return best
}

func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*addressState) bool) *addressState {
var deprecatedEndpoint *addressState
for _, ep := range a.primary {
if !isValid(ep) {
Expand Down Expand Up @@ -623,7 +599,7 @@ func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpi
a.mu.Lock()
defer a.mu.Unlock()

ep := a.acquirePrimaryAddressRLocked(remoteAddr, func(ep *addressState) bool {
ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool {
return ep.IsAssigned(allowExpired)
})

Expand Down
2 changes: 1 addition & 1 deletion pkg/tcpip/stack/stack_test.go
Expand Up @@ -1310,7 +1310,7 @@ func TestRoutes(t *testing.T) {
testRoute(t, s, 1, tcpip.AddrFromSlice([]byte("\x03\x00\x00\x00")), tcpip.AddrFromSlice([]byte("\x05\x00\x00\x00")), tcpip.AddrFromSlice([]byte("\x03\x00\x00\x00")))

// Test routes to even address.
testRoute(t, s, 0, tcpip.Address{}, tcpip.AddrFromSlice([]byte("\x06\x00\x00\x00")), tcpip.AddrFromSlice([]byte("\x04\x00\x00\x00")))
testRoute(t, s, 0, tcpip.Address{}, tcpip.AddrFromSlice([]byte("\x06\x00\x00\x00")), tcpip.AddrFromSlice([]byte("\x02\x00\x00\x00")))
testRoute(t, s, 0, tcpip.AddrFromSlice([]byte("\x02\x00\x00\x00")), tcpip.AddrFromSlice([]byte("\x06\x00\x00\x00")), tcpip.AddrFromSlice([]byte("\x02\x00\x00\x00")))
testRoute(t, s, 2, tcpip.AddrFromSlice([]byte("\x02\x00\x00\x00")), tcpip.AddrFromSlice([]byte("\x06\x00\x00\x00")), tcpip.AddrFromSlice([]byte("\x02\x00\x00\x00")))
testRoute(t, s, 0, tcpip.AddrFromSlice([]byte("\x04\x00\x00\x00")), tcpip.AddrFromSlice([]byte("\x06\x00\x00\x00")), tcpip.AddrFromSlice([]byte("\x04\x00\x00\x00")))
Expand Down

0 comments on commit baf097a

Please sign in to comment.