Skip to content

Commit

Permalink
Correctly handle interruptions in blocking msgqueue syscalls.
Browse files Browse the repository at this point in the history
Reported-by: syzbot+63bde04529f701c76168@syzkaller.appspotmail.com
Reported-by: syzbot+69866b9a16ec29993e6a@syzkaller.appspotmail.com
PiperOrigin-RevId: 389084629
  • Loading branch information
mrahatm authored and gvisor-bot committed Aug 6, 2021
1 parent 15853bd commit 569f605
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 43 deletions.
53 changes: 33 additions & 20 deletions pkg/sentry/kernel/msgqueue/msgqueue.go
Expand Up @@ -208,37 +208,44 @@ func (r *Registry) FindByID(id ipc.ID) (*Queue, error) {

// Send appends a message to the message queue, and returns an error if sending
// fails. See msgsnd(2).
func (q *Queue) Send(ctx context.Context, m Message, b Blocker, wait bool, pid int32) (err error) {
func (q *Queue) Send(ctx context.Context, m Message, b Blocker, wait bool, pid int32) error {
// Try to perform a non-blocking send using queue.append. If EWOULDBLOCK
// is returned, start the blocking procedure. Otherwise, return normally.
creds := auth.CredentialsFromContext(ctx)
if err := q.append(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK {

// Fast path: first attempt a non-blocking push.
if err := q.push(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK {
return err
}

if !wait {
return linuxerr.EAGAIN
}

// Slow path: at this point, the queue was found to be full, and we were
// asked to block.

e, ch := waiter.NewChannelEntry(nil)
q.senders.EventRegister(&e, waiter.EventOut)
defer q.senders.EventUnregister(&e)

// Note: we need to check again before blocking the first time since space
// may have become available.
for {
if err = q.append(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK {
break
if err := q.push(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK {
return err
}
if err := b.Block(ch); err != nil {
return err
}
b.Block(ch)
}

q.senders.EventUnregister(&e)
return err
}

// append appends a message to the queue's message list and notifies waiting
// push appends a message to the queue's message list and notifies waiting
// receivers that a message has been inserted. It returns an error if adding
// the message would cause the queue to exceed its maximum capacity, which can
// be used as a signal to block the task. Other errors should be returned as is.
func (q *Queue) append(ctx context.Context, m Message, creds *auth.Credentials, pid int32) error {
func (q *Queue) push(ctx context.Context, m Message, creds *auth.Credentials, pid int32) error {
if m.Type <= 0 {
return linuxerr.EINVAL
}
Expand Down Expand Up @@ -295,15 +302,14 @@ func (q *Queue) append(ctx context.Context, m Message, creds *auth.Credentials,
}

// Receive removes a message from the queue and returns it. See msgrcv(2).
func (q *Queue) Receive(ctx context.Context, b Blocker, mType int64, maxSize int64, wait, truncate, except bool, pid int32) (msg *Message, err error) {
func (q *Queue) Receive(ctx context.Context, b Blocker, mType int64, maxSize int64, wait, truncate, except bool, pid int32) (*Message, error) {
if maxSize < 0 || maxSize > maxMessageBytes {
return nil, linuxerr.EINVAL
}
max := uint64(maxSize)

// Try to perform a non-blocking receive using queue.pop. If EWOULDBLOCK
// is returned, start the blocking procedure. Otherwise, return normally.
creds := auth.CredentialsFromContext(ctx)

// Fast path: first attempt a non-blocking pop.
if msg, err := q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK {
return msg, err
}
Expand All @@ -312,24 +318,30 @@ func (q *Queue) Receive(ctx context.Context, b Blocker, mType int64, maxSize int
return nil, linuxerr.ENOMSG
}

// Slow path: at this point, the queue was found to be empty, and we were
// asked to block.

e, ch := waiter.NewChannelEntry(nil)
q.receivers.EventRegister(&e, waiter.EventIn)
defer q.receivers.EventUnregister(&e)

// Note: we need to check again before blocking the first time since a
// message may have become available.
for {
if msg, err = q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK {
break
if msg, err := q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK {
return msg, err
}
if err := b.Block(ch); err != nil {
return nil, err
}
b.Block(ch)
}
q.receivers.EventUnregister(&e)
return msg, err
}

// pop pops the first message from the queue that matches the given type. It
// returns an error for all the cases specified in msgrcv(2). If the queue is
// empty or no message of the specified type is available, a EWOULDBLOCK error
// is returned, which can then be used as a signal to block the process or fail.
func (q *Queue) pop(ctx context.Context, creds *auth.Credentials, mType int64, maxSize uint64, truncate, except bool, pid int32) (msg *Message, _ error) {
func (q *Queue) pop(ctx context.Context, creds *auth.Credentials, mType int64, maxSize uint64, truncate, except bool, pid int32) (*Message, error) {
q.mu.Lock()
defer q.mu.Unlock()

Expand All @@ -350,6 +362,7 @@ func (q *Queue) pop(ctx context.Context, creds *auth.Credentials, mType int64, m
}

// Get a message from the queue.
var msg *Message
switch {
case mType == 0:
msg = q.messages.Front()
Expand Down
3 changes: 2 additions & 1 deletion test/syscalls/linux/BUILD
Expand Up @@ -4173,10 +4173,11 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:capability_util",
"//test/util:signal_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
],
)
Expand Down
138 changes: 116 additions & 22 deletions test/syscalls/linux/msgqueue.cc
Expand Up @@ -13,12 +13,15 @@
// limitations under the License.

#include <errno.h>
#include <signal.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>

#include "absl/synchronization/notification.h"
#include "absl/time/clock.h"
#include "test/util/capability_util.h"
#include "test/util/signal_util.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
Expand All @@ -31,6 +34,8 @@ constexpr int msgMax = 8192; // Max size for message in bytes.
constexpr int msgMni = 32000; // Max number of identifiers.
constexpr int msgMnb = 16384; // Default max size of message queue in bytes.

constexpr int kInterruptSignal = SIGALRM;

// Queue is a RAII class used to automatically clean message queues.
class Queue {
public:
Expand Down Expand Up @@ -73,6 +78,12 @@ bool operator==(msgbuf& a, msgbuf& b) {
return a.mtype == b.mtype;
}

// msgmax represents a buffer for the largest possible single message.
struct msgmax {
int64_t mtype;
char mtext[msgMax];
};

// Test simple creation and retrieval for msgget(2).
TEST(MsgqueueTest, MsgGet) {
const TempPath keyfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
Expand Down Expand Up @@ -310,13 +321,6 @@ TEST(MsgqueueTest, MsgOpLimits) {
SyscallFailsWithErrno(EINVAL));

// Limit for queue.
// Use a buffer with the maximum mount of bytes that can be transformed to
// make it easier to exhaust the queue limit.
struct msgmax {
int64_t mtype;
char mtext[msgMax];
};

msgmax limit{1, ""};
for (size_t i = 0, msgCount = msgMnb / msgMax; i < msgCount; i++) {
EXPECT_THAT(msgsnd(queue.get(), &limit, sizeof(limit.mtext), 0),
Expand Down Expand Up @@ -470,13 +474,6 @@ TEST(MsgqueueTest, MsgSndBlocking) {
Queue queue(msgget(IPC_PRIVATE, 0600));
ASSERT_THAT(queue.get(), SyscallSucceeds());

// Use a buffer with the maximum mount of bytes that can be transformed to
// make it easier to exhaust the queue limit.
struct msgmax {
int64_t mtype;
char mtext[msgMax];
};

msgmax buf{1, ""}; // Has max amount of bytes.

const size_t msgCount = msgMnb / msgMax; // Number of messages that can be
Expand All @@ -494,6 +491,8 @@ TEST(MsgqueueTest, MsgSndBlocking) {
SyscallSucceeds());
});

const DisableSave ds; // Too many syscalls.

// To increase the chance of the last msgsnd blocking before doing a msgrcv,
// we use MSG_COPY option to copy the last index in the queue. As long as
// MSG_COPY fails, the queue hasn't yet been filled. When MSG_COPY succeeds,
Expand All @@ -516,15 +515,9 @@ TEST(MsgqueueTest, MsgSndRmWhileBlocking) {
Queue queue(msgget(IPC_PRIVATE, 0600));
ASSERT_THAT(queue.get(), SyscallSucceeds());

// Use a buffer with the maximum mount of bytes that can be transformed to
// make it easier to exhaust the queue limit.
struct msgmax {
int64_t mtype;
char mtext[msgMax];
};
// Number of messages that can be sent without blocking.
const size_t msgCount = msgMnb / msgMax;

const size_t msgCount = msgMnb / msgMax; // Number of messages that can be
// sent without blocking.
ScopedThread t([&] {
// Fill the queue.
msgmax buf{1, ""};
Expand All @@ -540,6 +533,8 @@ TEST(MsgqueueTest, MsgSndRmWhileBlocking) {
EXPECT_TRUE((errno == EIDRM || errno == EINVAL));
});

const DisableSave ds; // Too many syscalls.

// Similar to MsgSndBlocking, we do this to increase the chance of msgsnd
// blocking before removing the queue.
msgmax rcv;
Expand Down Expand Up @@ -627,6 +622,105 @@ TEST(MsgqueueTest, MsgOpGeneral) {
ScopedThread s10(sender(4));
}

void empty_sighandler(int sig, siginfo_t* info, void* context) {}

TEST(MsgqueueTest, InterruptRecv) {
Queue queue(msgget(IPC_PRIVATE, 0600));
char buf[64];

absl::Notification done, exit;

// Thread calling msgrcv with no corresponding send. It would block forever,
// but we'll interrupt with a signal below.
ScopedThread t([&] {
struct sigaction sa = {};
sa.sa_sigaction = empty_sighandler;
sigfillset(&sa.sa_mask);
sa.sa_flags = SA_SIGINFO;
auto cleanup_sigaction =
ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(kInterruptSignal, sa));
auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(
ScopedSignalMask(SIG_UNBLOCK, kInterruptSignal));

EXPECT_THAT(msgrcv(queue.get(), &buf, sizeof(buf), 0, 0),
SyscallFailsWithErrno(EINTR));

done.Notify();
exit.WaitForNotification();
});

const DisableSave ds; // Too many syscalls.

// We want the signal to arrive while msgrcv is blocking, but not after the
// thread has exited. Signals that arrive before msgrcv are no-ops.
do {
EXPECT_THAT(kill(getpid(), kInterruptSignal), SyscallSucceeds());
absl::SleepFor(absl::Milliseconds(100)); // Rate limit.
} while (!done.HasBeenNotified());

exit.Notify();
t.Join();
}

TEST(MsgqueueTest, InterruptSend) {
Queue queue(msgget(IPC_PRIVATE, 0600));
msgmax buf{1, ""};
// Number of messages that can be sent without blocking.
const size_t msgCount = msgMnb / msgMax;

// Fill the queue.
for (size_t i = 0; i < msgCount; i++) {
ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0),
SyscallSucceeds());
}

absl::Notification done, exit;

// Thread calling msgsnd on a full queue. It would block forever, but we'll
// interrupt with a signal below.
ScopedThread t([&] {
struct sigaction sa = {};
sa.sa_sigaction = empty_sighandler;
sigfillset(&sa.sa_mask);
sa.sa_flags = SA_SIGINFO;
auto cleanup_sigaction =
ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(kInterruptSignal, sa));
auto sa_cleanup = ASSERT_NO_ERRNO_AND_VALUE(
ScopedSignalMask(SIG_UNBLOCK, kInterruptSignal));

EXPECT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0),
SyscallFailsWithErrno(EINTR));

done.Notify();
exit.WaitForNotification();
});

const DisableSave ds; // Too many syscalls.

// We want the signal to arrive while msgsnd is blocking, but not after the
// thread has exited. Signals that arrive before msgsnd are no-ops.
do {
EXPECT_THAT(kill(getpid(), kInterruptSignal), SyscallSucceeds());
absl::SleepFor(absl::Milliseconds(100)); // Rate limit.
} while (!done.HasBeenNotified());

exit.Notify();
t.Join();
}

} // namespace
} // namespace testing
} // namespace gvisor

int main(int argc, char** argv) {
// Some tests depend on delivering a signal to the main thread. Block the
// target signal so that any other threads created by TestInit will also have
// the signal blocked.
sigset_t set;
sigemptyset(&set);
sigaddset(&set, gvisor::testing::kInterruptSignal);
TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);

gvisor::testing::TestInit(&argc, &argv);
return gvisor::testing::RunAllTests();
}

0 comments on commit 569f605

Please sign in to comment.