Skip to content

Commit

Permalink
Use network protocol default ttl instead of a hardcoded one.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 423886874
  • Loading branch information
milantracy authored and gvisor-bot committed Jan 24, 2022
1 parent 73f3395 commit 2b2f9ea
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
23 changes: 17 additions & 6 deletions pkg/sentry/socket/netstack/netstack.go
Original file line number Diff line number Diff line change
Expand Up @@ -1330,6 +1330,15 @@ func getSockOptICMPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
return nil, syserr.ErrProtocolNotAvailable
}

func defaultTTL(t *kernel.Task, network tcpip.NetworkProtocolNumber) (primitive.Int32, tcpip.Error) {
var opt tcpip.DefaultTTLOption
stack := inet.StackFromContext(t)
if err := stack.(*Stack).Stack.NetworkProtocolOption(network, &opt); err != nil {
return 0, err
}
return primitive.Int32(opt), nil
}

// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
if _, ok := ep.(tcpip.Endpoint); !ok {
Expand Down Expand Up @@ -1377,9 +1386,10 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
// Fill in the default value, if needed.
vP := primitive.Int32(v)
if vP == -1 {
// TODO(https://github.com/google/gvisor/issues/6973): Retrieve the
// configured DefaultTTLOption of the IPv6 protocol.
vP = DefaultTTL
vP, err = defaultTTL(t, header.IPv6ProtocolNumber)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
}

return &vP, nil
Expand Down Expand Up @@ -1540,9 +1550,10 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
// Fill in the default value, if needed.
vP := primitive.Int32(v)
if vP == 0 {
// TODO(https://github.com/google/gvisor/issues/6973): Retrieve the
// configured DefaultTTLOption of the IPv4 protocol.
vP = DefaultTTL
vP, err = defaultTTL(t, header.IPv4ProtocolNumber)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
}

return &vP, nil
Expand Down
1 change: 1 addition & 0 deletions test/syscalls/linux/socket_ip_unbound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ TEST_P(IPUnboundSocketTest, ResetTtlToDefault) {
EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &get2, &get2_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get2_sz, sizeof(get2));
EXPECT_TRUE(get2 == 64 || get2 == 127);
EXPECT_EQ(get2, get1);
}

Expand Down
8 changes: 7 additions & 1 deletion test/syscalls/linux/socket_ipv6_unbound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace testing {
namespace {

constexpr int kDefaultHopLimit = 64;
constexpr int kDefaultTtl = 64;

using ::testing::ValuesIn;
using IPv6UnboundSocketTest = SimpleSocketTest;
Expand All @@ -37,13 +38,18 @@ TEST_P(IPv6UnboundSocketTest, HopLimitDefault) {
std::unique_ptr<FileDescriptor> socket =
ASSERT_NO_ERRNO_AND_VALUE(NewSocket());

const int set = -1;
ASSERT_THAT(setsockopt(socket->get(), IPPROTO_IPV6, IPV6_UNICAST_HOPS, &set,
sizeof(set)),
SyscallSucceedsWithValue(0));

int get = -1;
socklen_t get_sz = sizeof(get);
ASSERT_THAT(
getsockopt(socket->get(), IPPROTO_IPV6, IPV6_UNICAST_HOPS, &get, &get_sz),
SyscallSucceedsWithValue(0));
ASSERT_EQ(get_sz, sizeof(get));
EXPECT_EQ(get, kDefaultHopLimit);
EXPECT_EQ(get, kDefaultTtl);
}

TEST_P(IPv6UnboundSocketTest, SetHopLimit) {
Expand Down

0 comments on commit 2b2f9ea

Please sign in to comment.