Skip to content
Permalink
Browse files Browse the repository at this point in the history
1.4 - Do not call into the VM unless the VM Context has been created. (
…#24)

* Ensure that the in VM Context is created before onDone is called.

Signed-off-by: John Plevyak <jplevyak@gmail.com>

* Update as per offline discussion.

Signed-off-by: John Plevyak <jplevyak@gmail.com>

* Set in_vm_context_created_ in onNetworkNewConnection.

Signed-off-by: John Plevyak <jplevyak@gmail.com>

* Add guards to other network calls.

Signed-off-by: John Plevyak <jplevyak@gmail.com>

* Fix common/wasm tests.

Signed-off-by: John Plevyak <jplevyak@gmail.com>

* Patch tests.

Signed-off-by: John Plevyak <jplevyak@gmail.com>

* Remove unecessary file from cherry-pick.

Signed-off-by: John Plevyak <jplevyak@gmail.com>
  • Loading branch information
jplevyak committed May 9, 2020
1 parent beea846 commit 8788a3c
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 14 deletions.
28 changes: 15 additions & 13 deletions source/extensions/common/wasm/wasm.cc
Expand Up @@ -1699,6 +1699,7 @@ void Context::onStart(absl::string_view root_id, absl::string_view vm_configurat
auto config_addr = wasm_->copyString(vm_configuration);
wasm_->onStart_(this, id_, root_id_addr, root_id.size(), config_addr, vm_configuration.size());
}
in_vm_context_created_ = true;
}

bool Context::validateConfiguration(absl::string_view configuration) {
Expand All @@ -1725,6 +1726,7 @@ void Context::onCreate(uint32_t root_context_id) {

Network::FilterStatus Context::onNetworkNewConnection() {
onCreate(root_context_id_);
in_vm_context_created_ = true;
if (!wasm_->onNewConnection_) {
return Network::FilterStatus::Continue;
}
Expand All @@ -1735,7 +1737,7 @@ Network::FilterStatus Context::onNetworkNewConnection() {
}

Network::FilterStatus Context::onDownstreamData(int data_length, bool end_of_stream) {
if (!wasm_->onDownstreamData_) {
if (!in_vm_context_created_ || !wasm_->onDownstreamData_) {
return Network::FilterStatus::Continue;
}
auto result = wasm_->onDownstreamData_(this, id_, static_cast<uint32_t>(data_length),
Expand All @@ -1745,7 +1747,7 @@ Network::FilterStatus Context::onDownstreamData(int data_length, bool end_of_str
}

Network::FilterStatus Context::onUpstreamData(int data_length, bool end_of_stream) {
if (!wasm_->onUpstreamData_) {
if (!in_vm_context_created_ || !wasm_->onUpstreamData_) {
return Network::FilterStatus::Continue;
}
auto result = wasm_->onUpstreamData_(this, id_, static_cast<uint32_t>(data_length),
Expand All @@ -1755,13 +1757,13 @@ Network::FilterStatus Context::onUpstreamData(int data_length, bool end_of_strea
}

void Context::onDownstreamConnectionClose(PeerType peer_type) {
if (wasm_->onDownstreamConnectionClose_) {
if (in_vm_context_created_ && wasm_->onDownstreamConnectionClose_) {
wasm_->onDownstreamConnectionClose_(this, id_, static_cast<uint32_t>(peer_type));
}
}

void Context::onUpstreamConnectionClose(PeerType peer_type) {
if (wasm_->onUpstreamConnectionClose_) {
if (in_vm_context_created_ && wasm_->onUpstreamConnectionClose_) {
wasm_->onUpstreamConnectionClose_(this, id_, static_cast<uint32_t>(peer_type));
}
}
Expand All @@ -1785,7 +1787,7 @@ Http::FilterHeadersStatus Context::onRequestHeaders() {
}

Http::FilterDataStatus Context::onRequestBody(int body_buffer_length, bool end_of_stream) {
if (!wasm_->onRequestBody_) {
if (!in_vm_context_created_ || !wasm_->onRequestBody_) {
return Http::FilterDataStatus::Continue;
}
switch (wasm_
Expand All @@ -1804,7 +1806,7 @@ Http::FilterDataStatus Context::onRequestBody(int body_buffer_length, bool end_o
}

Http::FilterTrailersStatus Context::onRequestTrailers() {
if (!wasm_->onRequestTrailers_) {
if (!in_vm_context_created_ || !wasm_->onRequestTrailers_) {
return Http::FilterTrailersStatus::Continue;
}
if (wasm_->onRequestTrailers_(this, id_).u64_ == 0) {
Expand All @@ -1814,7 +1816,7 @@ Http::FilterTrailersStatus Context::onRequestTrailers() {
}

Http::FilterMetadataStatus Context::onRequestMetadata() {
if (!wasm_->onRequestMetadata_) {
if (!in_vm_context_created_ || !wasm_->onRequestMetadata_) {
return Http::FilterMetadataStatus::Continue;
}
if (wasm_->onRequestMetadata_(this, id_).u64_ == 0) {
Expand Down Expand Up @@ -1842,7 +1844,7 @@ Http::FilterHeadersStatus Context::onResponseHeaders() {
}

Http::FilterDataStatus Context::onResponseBody(int body_buffer_length, bool end_of_stream) {
if (!wasm_->onResponseBody_) {
if (!in_vm_context_created_ || !wasm_->onResponseBody_) {
return Http::FilterDataStatus::Continue;
}
switch (wasm_
Expand All @@ -1861,7 +1863,7 @@ Http::FilterDataStatus Context::onResponseBody(int body_buffer_length, bool end_
}

Http::FilterTrailersStatus Context::onResponseTrailers() {
if (!wasm_->onResponseTrailers_) {
if (!in_vm_context_created_ || !wasm_->onResponseTrailers_) {
return Http::FilterTrailersStatus::Continue;
}
if (wasm_->onResponseTrailers_(this, id_).u64_ == 0) {
Expand All @@ -1871,7 +1873,7 @@ Http::FilterTrailersStatus Context::onResponseTrailers() {
}

Http::FilterMetadataStatus Context::onResponseMetadata() {
if (!wasm_->onResponseMetadata_) {
if (!in_vm_context_created_ || !wasm_->onResponseMetadata_) {
return Http::FilterMetadataStatus::Continue;
}
if (wasm_->onResponseMetadata_(this, id_).u64_ == 0) {
Expand Down Expand Up @@ -2445,19 +2447,19 @@ void Context::onDestroy() {
}

void Context::onDone() {
if (wasm_->onDone_) {
if (in_vm_context_created_ && wasm_->onDone_) {
wasm_->onDone_(this, id_);
}
}

void Context::onLog() {
if (wasm_->onLog_) {
if (in_vm_context_created_ && wasm_->onLog_) {
wasm_->onLog_(this, id_);
}
}

void Context::onDelete() {
if (wasm_->onDelete_) {
if (in_vm_context_created_ && wasm_->onDelete_) {
wasm_->onDelete_(this, id_);
}
}
Expand Down
2 changes: 2 additions & 0 deletions source/extensions/common/wasm/wasm.h
Expand Up @@ -413,6 +413,8 @@ class Context : public Logger::Loggable<Logger::Id::wasm>,
// Connection
virtual bool isSsl();

void setInVmContextCreatedForTesting() { in_vm_context_created_ = true; }

protected:
friend class Wasm;
friend struct AsyncClientHandler;
Expand Down
3 changes: 2 additions & 1 deletion test/extensions/wasm/wasm_test.cc
Expand Up @@ -184,7 +184,7 @@ TEST_P(WasmTest, DivByZero) {
auto context = std::make_unique<TestContext>(wasm.get());
EXPECT_CALL(*context, scriptLog_(spdlog::level::err, Eq("before div by zero")));
EXPECT_TRUE(wasm->initialize(code, false));
wasm->setContext(context.get());
context->setInVmContextCreatedForTesting();

if (GetParam() == "v8") {
EXPECT_THROW_WITH_MESSAGE(
Expand Down Expand Up @@ -388,6 +388,7 @@ TEST_P(WasmTest, StatsHighLevel) {
"{{ test_rundir }}/test/extensions/wasm/test_data/stats_cpp.wasm"));
EXPECT_FALSE(code.empty());
auto context = std::make_unique<TestContext>(wasm.get());
context->setInVmContextCreatedForTesting();

EXPECT_CALL(*context, scriptLog_(spdlog::level::trace, Eq("get counter = 1")));
EXPECT_CALL(*context, scriptLog_(spdlog::level::debug, Eq("get counter = 2")));
Expand Down

0 comments on commit 8788a3c

Please sign in to comment.