Skip to content

Commit

Permalink
[binder] Use AParcel_getDataSize() in flow-control (#27257)
Browse files Browse the repository at this point in the history
  • Loading branch information
TaWeiTu committed Sep 11, 2021
1 parent 52e5b64 commit 71ceae7
Show file tree
Hide file tree
Showing 12 changed files with 88 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ const absl::string_view
void TransportStreamReceiverImpl::RegisterRecvInitialMetadata(
StreamIdentifier id, InitialMetadataCallbackType cb) {
gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_);
GPR_ASSERT(initial_metadata_cbs_.count(id) == 0);
absl::StatusOr<Metadata> initial_metadata{};
{
grpc_core::MutexLock l(&m_);
GPR_ASSERT(initial_metadata_cbs_.count(id) == 0);
auto iter = pending_initial_metadata_.find(id);
if (iter == pending_initial_metadata_.end()) {
if (trailing_metadata_recvd_.count(id)) {
Expand All @@ -59,10 +59,10 @@ void TransportStreamReceiverImpl::RegisterRecvInitialMetadata(
void TransportStreamReceiverImpl::RegisterRecvMessage(
StreamIdentifier id, MessageDataCallbackType cb) {
gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_);
GPR_ASSERT(message_cbs_.count(id) == 0);
absl::StatusOr<std::string> message{};
{
grpc_core::MutexLock l(&m_);
GPR_ASSERT(message_cbs_.count(id) == 0);
auto iter = pending_message_.find(id);
if (iter == pending_message_.end()) {
// If we'd already received trailing-metadata and there's no pending
Expand Down Expand Up @@ -93,10 +93,10 @@ void TransportStreamReceiverImpl::RegisterRecvMessage(
void TransportStreamReceiverImpl::RegisterRecvTrailingMetadata(
StreamIdentifier id, TrailingMetadataCallbackType cb) {
gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_);
GPR_ASSERT(trailing_metadata_cbs_.count(id) == 0);
std::pair<absl::StatusOr<Metadata>, int> trailing_metadata{};
{
grpc_core::MutexLock l(&m_);
GPR_ASSERT(trailing_metadata_cbs_.count(id) == 0);
auto iter = pending_trailing_metadata_.find(id);
if (iter == pending_trailing_metadata_.end()) {
trailing_metadata_cbs_[id] = std::move(cb);
Expand Down
2 changes: 2 additions & 0 deletions src/core/ext/transport/binder/wire_format/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class WritableParcel {
public:
virtual ~WritableParcel() = default;
virtual int32_t GetDataPosition() const = 0;
virtual int32_t GetDataSize() const = 0;
virtual absl::Status SetDataPosition(int32_t pos) = 0;
virtual absl::Status WriteInt32(int32_t data) = 0;
virtual absl::Status WriteInt64(int64_t data) = 0;
Expand All @@ -67,6 +68,7 @@ class WritableParcel {
class ReadableParcel {
public:
virtual ~ReadableParcel() = default;
virtual int32_t GetDataSize() const = 0;
virtual absl::Status ReadInt32(int32_t* data) const = 0;
virtual absl::Status ReadInt64(int64_t* data) const = 0;
virtual absl::Status ReadBinder(std::unique_ptr<Binder>* data) const = 0;
Expand Down
22 changes: 21 additions & 1 deletion src/core/ext/transport/binder/wire_format/binder_android.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@
#include "src/core/ext/transport/binder/wire_format/binder_android.h"
#include "src/core/lib/gprpp/sync.h"

extern "C" {
// TODO(mingcl): This function is introduced at API level 32 and is not
// available in any NDK release yet. So we export it weakly so that we can use
// it without triggering undefined reference error. Its purpose is to disable
// header in Parcel to conform to the BinderChannel wire format.
extern "C" {
extern void AIBinder_Class_disableInterfaceTokenHeader(AIBinder_Class* clazz)
__attribute__((weak));
// This is released in API level 31.
extern int32_t AParcel_getDataSize(const AParcel* parcel) __attribute__((weak));
}

namespace grpc_binder {
Expand Down Expand Up @@ -194,6 +196,15 @@ int32_t WritableParcelAndroid::GetDataPosition() const {
return AParcel_getDataPosition(parcel_);
}

int32_t WritableParcelAndroid::GetDataSize() const {
if (AParcel_getDataSize) {
return AParcel_getDataSize(parcel_);
} else {
gpr_log(GPR_INFO, "[Warning] AParcel_getDataSize is not available");
return 0;
}
}

absl::Status WritableParcelAndroid::SetDataPosition(int32_t pos) {
return AParcel_setDataPosition(parcel_, pos) == STATUS_OK
? absl::OkStatus()
Expand Down Expand Up @@ -233,6 +244,15 @@ absl::Status WritableParcelAndroid::WriteByteArray(const int8_t* buffer,
: absl::InternalError("AParcel_writeByteArray failed");
}

int32_t ReadableParcelAndroid::GetDataSize() const {
if (AParcel_getDataSize) {
return AParcel_getDataSize(parcel_);
} else {
gpr_log(GPR_INFO, "[Warning] AParcel_getDataSize is not available");
return -1;
}
}

absl::Status ReadableParcelAndroid::ReadInt32(int32_t* data) const {
return AParcel_readInt32(parcel_, data) == STATUS_OK
? absl::OkStatus()
Expand Down
2 changes: 2 additions & 0 deletions src/core/ext/transport/binder/wire_format/binder_android.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class WritableParcelAndroid final : public WritableParcel {
~WritableParcelAndroid() override = default;

int32_t GetDataPosition() const override;
int32_t GetDataSize() const override;
absl::Status SetDataPosition(int32_t pos) override;
absl::Status WriteInt32(int32_t data) override;
absl::Status WriteInt64(int64_t data) override;
Expand All @@ -66,6 +67,7 @@ class ReadableParcelAndroid final : public ReadableParcel {
: parcel_(const_cast<AParcel*>(parcel)) {}
~ReadableParcelAndroid() override = default;

int32_t GetDataSize() const override;
absl::Status ReadInt32(int32_t* data) const override;
absl::Status ReadInt64(int64_t* data) const override;
absl::Status ReadBinder(std::unique_ptr<Binder>* data) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ absl::Status WireReaderImpl::ProcessStreamingTransactionImpl(
transaction_code_t code, const ReadableParcel* parcel,
int* cancellation_flags) {
GPR_ASSERT(cancellation_flags);
num_incoming_bytes_ += parcel->GetDataSize();

int flags;
RETURN_IF_ERROR(parcel->ReadInt32(&flags));
Expand Down Expand Up @@ -342,8 +343,6 @@ absl::Status WireReaderImpl::ProcessStreamingTransactionImpl(
}
gpr_log(GPR_INFO, "msg_data = %s", msg_data.c_str());
message_buffer_[code] += msg_data;
// TODO(waynetu): This should be parcel->GetDataSize().
num_incoming_bytes_ += count;
if ((flags & kFlagMessageDataIsPartial) == 0) {
std::string s = std::move(message_buffer_[code]);
message_buffer_.erase(code);
Expand Down
3 changes: 1 addition & 2 deletions src/core/ext/transport/binder/wire_format/wire_writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,9 @@ absl::Status WireWriterImpl::RpcCall(const Transaction& tx) {
if (flags & kFlagSuffix) {
RETURN_IF_ERROR(WriteTrailingMetadata(tx, parcel));
}
num_outgoing_bytes_ += parcel->GetDataSize();
RETURN_IF_ERROR(binder_->Transact(BinderTransportTxCode(tx.GetTxCode())));
bytes_sent += size;
// TODO(waynetu): This should be parcel->GetDataSize().
num_outgoing_bytes_ += size;
}
return absl::OkStatus();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,22 +293,25 @@ TEST_P(End2EndBinderTransportTest, BiDirStreamingCallThroughFakeBinderChannel) {
server->Shutdown();
}

TEST_P(End2EndBinderTransportTest, LargeMessage) {
TEST_P(End2EndBinderTransportTest, LargeMessages) {
grpc::ChannelArguments args;
grpc::ServerBuilder builder;
end2end_testing::EchoServer service;
builder.RegisterService(&service);
std::unique_ptr<grpc::Server> server = builder.BuildAndStart();
std::shared_ptr<grpc::Channel> channel = BinderChannel(server.get(), args);
std::unique_ptr<EchoService::Stub> stub = EchoService::NewStub(channel);
grpc::ClientContext context;
EchoRequest request;
EchoResponse response;
request.set_text(std::string(1000000, 'a'));
grpc::Status status = stub->EchoUnaryCall(&context, request, &response);
EXPECT_TRUE(status.ok());
EXPECT_EQ(response.text(), std::string(1000000, 'a'));

for (size_t size = 1; size <= 1024 * 1024; size *= 4) {
grpc::ClientContext context;
EchoRequest request;
EchoResponse response;
request.set_text(std::string(size, 'a'));
grpc::Status status = stub->EchoUnaryCall(&context, request, &response);
EXPECT_TRUE(status.ok());
EXPECT_EQ(response.text().size(), size);
EXPECT_TRUE(std::all_of(response.text().begin(), response.text().end(),
[](char c) { return c == 'a'; }));
}
server->Shutdown();
}

Expand Down
9 changes: 9 additions & 0 deletions test/core/transport/binder/end2end/fake_binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ FakeWritableParcel::FakeWritableParcel() : data_(1) {}

int32_t FakeWritableParcel::GetDataPosition() const { return data_position_; }

int32_t FakeWritableParcel::GetDataSize() const { return data_size_; }

absl::Status FakeWritableParcel::SetDataPosition(int32_t pos) {
if (data_.size() < static_cast<size_t>(pos) + 1) {
data_.resize(pos + 1);
Expand All @@ -39,34 +41,41 @@ absl::Status FakeWritableParcel::SetDataPosition(int32_t pos) {
absl::Status FakeWritableParcel::WriteInt32(int32_t data) {
data_[data_position_] = data;
SetDataPosition(data_position_ + 1).IgnoreError();
data_size_ += 4;
return absl::OkStatus();
}

absl::Status FakeWritableParcel::WriteInt64(int64_t data) {
data_[data_position_] = data;
SetDataPosition(data_position_ + 1).IgnoreError();
data_size_ += 8;
return absl::OkStatus();
}

absl::Status FakeWritableParcel::WriteBinder(HasRawBinder* binder) {
data_[data_position_] = binder->GetRawBinder();
SetDataPosition(data_position_ + 1).IgnoreError();
data_size_ += 8;
return absl::OkStatus();
}

absl::Status FakeWritableParcel::WriteString(absl::string_view s) {
data_[data_position_] = std::string(s);
SetDataPosition(data_position_ + 1).IgnoreError();
data_size_ += s.size();
return absl::OkStatus();
}

absl::Status FakeWritableParcel::WriteByteArray(const int8_t* buffer,
int32_t length) {
data_[data_position_] = std::vector<int8_t>(buffer, buffer + length);
SetDataPosition(data_position_ + 1).IgnoreError();
data_size_ += length;
return absl::OkStatus();
}

int32_t FakeReadableParcel::GetDataSize() const { return data_size_; }

absl::Status FakeReadableParcel::ReadInt32(int32_t* data) const {
if (data_position_ >= data_.size() ||
!absl::holds_alternative<int32_t>(data_[data_position_])) {
Expand Down
21 changes: 20 additions & 1 deletion test/core/transport/binder/end2end/fake_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class FakeWritableParcel final : public WritableParcel {
public:
FakeWritableParcel();
int32_t GetDataPosition() const override;
int32_t GetDataSize() const override;
absl::Status SetDataPosition(int32_t pos) override;
absl::Status WriteInt32(int32_t data) override;
absl::Status WriteInt64(int64_t data) override;
Expand All @@ -96,6 +97,7 @@ class FakeWritableParcel final : public WritableParcel {
private:
FakeData data_;
size_t data_position_ = 0;
int32_t data_size_ = 0;
};

// A fake readable parcel.
Expand All @@ -104,7 +106,23 @@ class FakeWritableParcel final : public WritableParcel {
// methods to retrieve those data in the receiving end.
class FakeReadableParcel final : public ReadableParcel {
public:
explicit FakeReadableParcel(FakeData data) : data_(std::move(data)) {}
explicit FakeReadableParcel(FakeData data) : data_(std::move(data)) {
for (auto& d : data_) {
if (absl::holds_alternative<int32_t>(d)) {
data_size_ += 4;
} else if (absl::holds_alternative<int64_t>(d)) {
data_size_ += 8;
} else if (absl::holds_alternative<void*>(d)) {
data_size_ += 8;
} else if (absl::holds_alternative<std::string>(d)) {
data_size_ += absl::get<std::string>(d).size();
} else {
data_size_ += absl::get<std::vector<int8_t>>(d).size();
}
}
}

int32_t GetDataSize() const override;
absl::Status ReadInt32(int32_t* data) const override;
absl::Status ReadInt64(int64_t* data) const override;
absl::Status ReadBinder(std::unique_ptr<Binder>* data) const override;
Expand All @@ -114,6 +132,7 @@ class FakeReadableParcel final : public ReadableParcel {
private:
const FakeData data_;
mutable size_t data_position_ = 0;
int32_t data_size_ = 0;
};

class FakeBinder;
Expand Down
2 changes: 2 additions & 0 deletions test/core/transport/binder/mock_objects.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace grpc_binder {
class MockWritableParcel : public WritableParcel {
public:
MOCK_METHOD(int32_t, GetDataPosition, (), (const, override));
MOCK_METHOD(int32_t, GetDataSize, (), (const, override));
MOCK_METHOD(absl::Status, SetDataPosition, (int32_t), (override));
MOCK_METHOD(absl::Status, WriteInt32, (int32_t), (override));
MOCK_METHOD(absl::Status, WriteInt64, (int64_t), (override));
Expand All @@ -40,6 +41,7 @@ class MockWritableParcel : public WritableParcel {

class MockReadableParcel : public ReadableParcel {
public:
MOCK_METHOD(int32_t, GetDataSize, (), (const, override));
MOCK_METHOD(absl::Status, ReadInt32, (int32_t*), (const, override));
MOCK_METHOD(absl::Status, ReadInt64, (int64_t*), (const, override));
MOCK_METHOD(absl::Status, ReadBinder, (std::unique_ptr<Binder>*),
Expand Down
5 changes: 4 additions & 1 deletion test/core/transport/binder/wire_reader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class WireReaderTest : public ::testing::Test {
std::shared_ptr<StrictMock<MockTransportStreamReceiver>>
transport_stream_receiver_;
WireReaderImpl wire_reader_;
StrictMock<MockReadableParcel> mock_readable_parcel_;
MockReadableParcel mock_readable_parcel_;
};

MATCHER_P(StatusOrStrEq, target, "") {
Expand Down Expand Up @@ -279,6 +279,8 @@ TEST_F(WireReaderTest, InBoundFlowControl) {
::testing::InSequence sequence;
UnblockSetupTransport();

// data size
EXPECT_CALL(mock_readable_parcel_, GetDataSize).WillOnce(Return(1000));
// flag
ExpectReadInt32(kFlagMessageData | kFlagMessageDataIsPartial);
// sequence number
Expand All @@ -292,6 +294,7 @@ TEST_F(WireReaderTest, InBoundFlowControl) {
// Data is not completed. No callback will be triggered.
EXPECT_TRUE(CallProcessTransaction(kFirstCallId).ok());

EXPECT_CALL(mock_readable_parcel_, GetDataSize).WillOnce(Return(1000));
// flag
ExpectReadInt32(kFlagMessageData);
// sequence number
Expand Down
13 changes: 11 additions & 2 deletions test/core/transport/binder/wire_writer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
namespace grpc_binder {

using ::testing::Return;
using ::testing::StrictMock;

MATCHER_P(StrEqInt8Ptr, target, "") {
return std::string(reinterpret_cast<const char*>(arg), target.size()) ==
Expand All @@ -37,7 +36,7 @@ MATCHER_P(StrEqInt8Ptr, target, "") {
TEST(WireWriterTest, RpcCall) {
auto mock_binder = absl::make_unique<MockBinder>();
MockBinder& mock_binder_ref = *mock_binder;
StrictMock<MockWritableParcel> mock_writable_parcel;
MockWritableParcel mock_writable_parcel;
ON_CALL(mock_binder_ref, GetWritableParcel)
.WillByDefault(Return(&mock_writable_parcel));
WireWriterImpl wire_writer(std::move(mock_binder));
Expand Down Expand Up @@ -176,19 +175,24 @@ TEST(WireWriterTest, RpcCall) {
WriteInt32(kFlagMessageData | kFlagMessageDataIsPartial));
EXPECT_CALL(mock_writable_parcel, WriteInt32(0));
ExpectWriteByteArray(std::string(WireWriterImpl::kBlockSize, 'a'));
EXPECT_CALL(mock_writable_parcel, GetDataSize)
.WillOnce(Return(WireWriterImpl::kBlockSize));
EXPECT_CALL(mock_binder_ref,
Transact(BinderTransportTxCode(kFirstCallId + 2)));

EXPECT_CALL(mock_writable_parcel,
WriteInt32(kFlagMessageData | kFlagMessageDataIsPartial));
EXPECT_CALL(mock_writable_parcel, WriteInt32(1));
ExpectWriteByteArray(std::string(WireWriterImpl::kBlockSize, 'a'));
EXPECT_CALL(mock_writable_parcel, GetDataSize)
.WillOnce(Return(WireWriterImpl::kBlockSize));
EXPECT_CALL(mock_binder_ref,
Transact(BinderTransportTxCode(kFirstCallId + 2)));

EXPECT_CALL(mock_writable_parcel, WriteInt32(kFlagMessageData));
EXPECT_CALL(mock_writable_parcel, WriteInt32(2));
ExpectWriteByteArray("a");
EXPECT_CALL(mock_writable_parcel, GetDataSize).WillOnce(Return(1));
EXPECT_CALL(mock_binder_ref,
Transact(BinderTransportTxCode(kFirstCallId + 2)));

Expand All @@ -206,20 +210,25 @@ TEST(WireWriterTest, RpcCall) {
EXPECT_CALL(mock_writable_parcel, WriteString(absl::string_view("123")));
EXPECT_CALL(mock_writable_parcel, WriteInt32(0));
ExpectWriteByteArray(std::string(WireWriterImpl::kBlockSize, 'a'));
EXPECT_CALL(mock_writable_parcel, GetDataSize)
.WillOnce(Return(WireWriterImpl::kBlockSize));
EXPECT_CALL(mock_binder_ref,
Transact(BinderTransportTxCode(kFirstCallId + 3)));

EXPECT_CALL(mock_writable_parcel,
WriteInt32(kFlagMessageData | kFlagMessageDataIsPartial));
EXPECT_CALL(mock_writable_parcel, WriteInt32(1));
ExpectWriteByteArray(std::string(WireWriterImpl::kBlockSize, 'a'));
EXPECT_CALL(mock_writable_parcel, GetDataSize)
.WillOnce(Return(WireWriterImpl::kBlockSize));
EXPECT_CALL(mock_binder_ref,
Transact(BinderTransportTxCode(kFirstCallId + 3)));

EXPECT_CALL(mock_writable_parcel,
WriteInt32(kFlagMessageData | kFlagSuffix));
EXPECT_CALL(mock_writable_parcel, WriteInt32(2));
ExpectWriteByteArray("a");
EXPECT_CALL(mock_writable_parcel, GetDataSize).WillOnce(Return(1));
EXPECT_CALL(mock_binder_ref,
Transact(BinderTransportTxCode(kFirstCallId + 3)));

Expand Down

0 comments on commit 71ceae7

Please sign in to comment.