Skip to content

Commit

Permalink
Give Cpp2ConnContext access to incoming request headers.
Browse files Browse the repository at this point in the history
Summary:
Give the connection-context access to incomming request headers.

This is to actually give AdmissionStrategies some context when receiving
requests. For example PerClientIdAdmissionStrategy doesn't have access to
ClientId without these headers.

Reviewed By: stevegury

Differential Revision: D13983173

fbshipit-source-id: e66b02da01331aee77ed257ab62ed89b773a791b
  • Loading branch information
Rodolfo Granata authored and facebook-github-bot committed Feb 14, 2019
1 parent cbebb00 commit 9443939
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 147 deletions.
4 changes: 0 additions & 4 deletions thrift/lib/cpp2/server/AdmissionController.h
Expand Up @@ -23,10 +23,6 @@
namespace apache {
namespace thrift {

class BaseThriftServer;
class ResponseChannelRequest;
class Cpp2ConnContext;

class AdmissionController {
public:
using MetricReportFn =
Expand Down
2 changes: 1 addition & 1 deletion thrift/lib/cpp2/server/Cpp2Connection.cpp
Expand Up @@ -321,7 +321,7 @@ void Cpp2Connection::requestReceived(unique_ptr<ResponseChannelRequest>&& req) {

auto admissionStrategy = worker_->getServer()->getAdmissionStrategy();
auto admissionController =
admissionStrategy->select(methodName, *req, context_);
admissionStrategy->select(methodName, hreq->getHeader());
if (!admissionController->admit()) {
killRequest(
*req,
Expand Down
Expand Up @@ -38,8 +38,7 @@ class AcceptAllAdmissionStrategy : public AdmissionStrategy {

std::shared_ptr<AdmissionController> select(
const std::string&,
const ResponseChannelRequest&,
const Cpp2ConnContext&) override {
const transport::THeader*) override {
return admissionController_;
}

Expand Down
10 changes: 3 additions & 7 deletions thrift/lib/cpp2/server/admission_strategy/AdmissionStrategy.h
Expand Up @@ -18,14 +18,12 @@

#include <memory>

#include <thrift/lib/cpp2/async/ResponseChannel.h>
#include <thrift/lib/cpp/transport/THeader.h>
#include <thrift/lib/cpp2/server/AdmissionController.h>

namespace apache {
namespace thrift {

class BaseThriftServer;

class AdmissionStrategy {
public:
enum Type { ACCEPT_ALL = 0, GLOBAL = 1, PER_CLIENT_ID = 2, PRIORITY = 3 };
Expand All @@ -39,13 +37,11 @@ class AdmissionStrategy {
* Select an AdmissionController to be used for this specific request.
* This selection can be made based on the arguments which are:
* - methodName: the name of the Thrift method called
* - request: An object representing the serialized request
* - connContext: The connection context allowing you to access metadata
* - tHeader: transport header allowing access to request headers
*/
virtual std::shared_ptr<AdmissionController> select(
const std::string& methodName,
const ResponseChannelRequest& request,
const Cpp2ConnContext& connContext) = 0;
const transport::THeader* tHeader) = 0;

virtual void reportMetrics(
const MetricReportFn&,
Expand Down
Expand Up @@ -47,8 +47,7 @@ class GlobalAdmissionStrategy : public AdmissionStrategy {
*/
std::shared_ptr<AdmissionController> select(
const std::string&,
const ResponseChannelRequest&,
const Cpp2ConnContext&) override {
const transport::THeader*) override {
return admissionController_;
}

Expand Down
Expand Up @@ -62,15 +62,14 @@ class PerClientIdAdmissionStrategy : public AdmissionStrategy {
*/
std::shared_ptr<AdmissionController> select(
const std::string&,
const ResponseChannelRequest&,
const Cpp2ConnContext& connContext) override {
const auto* headers = connContext.getHeadersPtr();
if (headers == nullptr) {
const transport::THeader* theader) override {
if (theader == nullptr) {
return wildcardController_;
}

auto headersIt = headers->find(clientIdHeaderName_);
if (headersIt == headers->end() || headersIt->second == kWildcard) {
const auto& headers = theader->getHeaders();
auto headersIt = headers.find(clientIdHeaderName_);
if (headersIt == headers.end() || headersIt->second == kWildcard) {
return wildcardController_;
}

Expand Down
Expand Up @@ -78,7 +78,7 @@ class PriorityAdmissionStrategy : public AdmissionStrategy {
*
* `priorities` is a map of clientId to absolute priority.
* `factory` is the function used to create a new admission controller
* `clientIdHeaderName` is the header read from the connexion context to
* `clientIdHeaderName` is the header read from request headers to
* identify a client
*/
PriorityAdmissionStrategy(
Expand Down Expand Up @@ -112,9 +112,8 @@ class PriorityAdmissionStrategy : public AdmissionStrategy {
*/
std::shared_ptr<AdmissionController> select(
const std::string&,
const ResponseChannelRequest&,
const Cpp2ConnContext& connContext) override {
auto bucketIndex = computeBucketIndex(connContext);
const transport::THeader* tHeader) override {
auto bucketIndex = computeBucketIndex(tHeader);
if (bucketIndex < 0) {
return denyAdmissionController_;
}
Expand Down Expand Up @@ -144,7 +143,7 @@ class PriorityAdmissionStrategy : public AdmissionStrategy {

private:
/**
* Compute a bucket index based on the connection context.
* Compute a bucket index based on the client-id
*
* This method compute a bucket index based on the priorities.
* E.g. for priorities like this: {"A": 1, "B": 3, WILDCARD: 1}
Expand All @@ -153,8 +152,8 @@ class PriorityAdmissionStrategy : public AdmissionStrategy {
* "B" -> 1, 2, 3 (returned in a round-robin way)
* WILDCARD -> 4
*/
int computeBucketIndex(const Cpp2ConnContext& connContext) {
auto& priority = getPriority(connContext);
int computeBucketIndex(const transport::THeader* theader) {
auto& priority = getPriority(theader);
if (priority.priority == 0) {
return -1; // deny all requests if priority == 0
}
Expand All @@ -163,11 +162,11 @@ class PriorityAdmissionStrategy : public AdmissionStrategy {
return index;
}

Priority& getPriority(const Cpp2ConnContext& connContext) {
const auto* headers = connContext.getHeadersPtr();
if (headers != nullptr) {
auto clientIdIt = headers->find(clientIdHeaderName_);
if (clientIdIt != headers->end()) {
Priority& getPriority(const transport::THeader* theader) {
if (theader != nullptr) {
const auto& headers = theader->getHeaders();
auto clientIdIt = headers.find(clientIdHeaderName_);
if (clientIdIt != headers.end()) {
auto priorityIt = priorities_.find(clientIdIt->second);
if (priorityIt != priorities_.end()) {
return priorityIt->second;
Expand Down
Expand Up @@ -61,12 +61,11 @@ class WhitelistAdmissionStrategy : public AdmissionStrategy {

std::shared_ptr<AdmissionController> select(
const std::string& methodName,
const ResponseChannelRequest& request,
const Cpp2ConnContext& connContext) override {
const transport::THeader* tHeader) override {
if (whitelist_.find(methodName) != whitelist_.end()) {
return acceptAllAdmissionController_;
}
return innerStrategy_.select(methodName, request, connContext);
return innerStrategy_.select(methodName, tHeader);
}

void reportMetrics(
Expand Down
150 changes: 40 additions & 110 deletions thrift/lib/cpp2/test/AdmissionStrategyTest.cpp
Expand Up @@ -44,23 +44,6 @@ namespace thrift {

FakeClock::time_point FakeClock::now_us_;

class DummyRequest : public ResponseChannelRequest {
bool isActive() override {
return true;
}
void cancel() override {}
bool isOneway() override {
return false;
}
void sendReply(std::unique_ptr<folly::IOBuf>&&, MessageChannel::SendCallback*)
override {}

void sendErrorWrapped(
folly::exception_wrapper,
std::string,
MessageChannel::SendCallback*) override {}
};

class DummyController : public AdmissionController {
public:
bool admit() override {
Expand All @@ -70,43 +53,6 @@ class DummyController : public AdmissionController {
void returnedResponse() override {}
};

class DummyConnContext : public Cpp2ConnContext {
public:
DummyConnContext() {
setRequestHeader(&reqHeader_);
}

void setReadHeader(const std::string& key, const std::string& value) {
auto headers = reqHeader_.releaseHeaders();
headers.insert({key, value});
reqHeader_.setReadHeaders(std::move(headers));
}

private:
THeader reqHeader_;
};

class DummyContextTest : public testing::Test {};

TEST_F(DummyContextTest, globalAdmission) {
DummyConnContext ctx;
ctx.setReadHeader("toto", "titi");
ctx.setReadHeader("tutu", "tata");

// Simulate below the way we extract the clientId from ConnContext
std::string clientId = "NONE";
auto header = ctx.getHeader();
if (header != nullptr) {
auto headers = header->getHeaders();
auto it = headers.find("toto");
if (it != headers.end()) {
clientId = it->second;
}
}

ASSERT_EQ(clientId, "titi");
}

class AdmissionControllerSelectorTest : public testing::Test {
public:
const std::string kClientId{"client_id"};
Expand All @@ -115,29 +61,25 @@ class AdmissionControllerSelectorTest : public testing::Test {
TEST_F(AdmissionControllerSelectorTest, globalAdmission) {
GlobalAdmissionStrategy selector(std::make_shared<DummyController>());

DummyRequest request;
DummyConnContext connContextA1;
connContextA1.setReadHeader(kClientId, "A");
auto admissionControllerA1 =
selector.select("myThriftMethod", request, connContextA1);
THeader headerA1;
headerA1.setReadHeaders({{kClientId, "A"}});
auto admissionControllerA1 = selector.select("myThriftMethod", &headerA1);

DummyConnContext connContextA2;
connContextA2.setReadHeader(kClientId, "A");
auto admissionControllerA2 =
selector.select("myThriftMethod", request, connContextA2);
THeader headerA2;
headerA2.setReadHeaders({{kClientId, "A"}});
auto admissionControllerA2 = selector.select("myThriftMethod", &headerA2);

ASSERT_EQ(admissionControllerA1, admissionControllerA2);

DummyConnContext connContextB1;
connContextB1.setReadHeader(kClientId, "B");
auto admissionControllerB1 =
selector.select("myThriftMethod", request, connContextB1);
THeader headerB1;
headerB1.setReadHeaders({{kClientId, "B"}});
auto admissionControllerB1 = selector.select("myThriftMethod", &headerB1);

ASSERT_EQ(admissionControllerA1, admissionControllerB1);

DummyConnContext connContextNoClientId;
THeader headerNoClientId;
auto admissionControllerNoClientId =
selector.select("myThriftMethod", request, connContextNoClientId);
selector.select("myThriftMethod", &headerNoClientId);

ASSERT_EQ(admissionControllerB1, admissionControllerNoClientId);
}
Expand All @@ -146,30 +88,25 @@ TEST_F(AdmissionControllerSelectorTest, perClientIdAdmission) {
PerClientIdAdmissionStrategy selector(
[](auto&) { return std::make_shared<DummyController>(); }, kClientId);

DummyRequest request;
DummyConnContext connContextA1;
connContextA1.setReadHeader(kClientId, "A");
auto admissionControllerA1 =
selector.select("myThriftMethod", request, connContextA1);
THeader headerA1;
headerA1.setReadHeaders({{kClientId, "A"}});
auto admissionControllerA1 = selector.select("myThriftMethod", &headerA1);

DummyConnContext connContextA2;
connContextA2.setReadHeader(kClientId, "A");
auto admissionControllerA2 =
selector.select("myThriftMethod", request, connContextA2);
THeader headerA2;
headerA2.setReadHeaders({{kClientId, "A"}});
auto admissionControllerA2 = selector.select("myThriftMethod", &headerA2);

ASSERT_EQ(admissionControllerA1, admissionControllerA2);

DummyConnContext connContextB1;
connContextB1.setReadHeader(kClientId, "B");
auto admissionControllerB1 =
selector.select("myThriftMethod", request, connContextB1);
THeader headerB1;
headerB1.setReadHeaders({{kClientId, "B"}});
auto admissionControllerB1 = selector.select("myThriftMethod", &headerB1);

ASSERT_NE(admissionControllerA1, admissionControllerB1);

DummyConnContext connContextB2;
connContextB2.setReadHeader(kClientId, "B");
auto admissionControllerB2 =
selector.select("myThriftMethod", request, connContextB2);
THeader headerB2;
headerB2.setReadHeaders({{kClientId, "B"}});
auto admissionControllerB2 = selector.select("myThriftMethod", &headerB2);

ASSERT_EQ(admissionControllerB1, admissionControllerB2);
}
Expand All @@ -191,10 +128,9 @@ TEST_F(AdmissionControllerSelectorTest, priorityBasedAdmission) {
auto& clientId = it.first;
auto& admControllerSet = it.second;
for (int i = 0; i < 5; i++) {
DummyRequest request;
DummyConnContext connContext;
connContext.setReadHeader(kClientId, clientId);
auto controller = selector.select("myThriftMethod", request, connContext);
THeader header;
header.setReadHeaders({{kClientId, clientId}});
auto controller = selector.select("myThriftMethod", &header);
admControllerSet.insert(controller);
}
}
Expand All @@ -221,20 +157,20 @@ TEST_F(AdmissionControllerSelectorTest, deniesZeroPriority) {
auto& clientId = it.first;
auto& admControllerSet = it.second;
for (int i = 0; i < 5; i++) {
DummyRequest request;
DummyConnContext connContext;
connContext.setReadHeader(kClientId, clientId);
auto controller = selector.select("myThriftMethod", request, connContext);
THeader header;
header.setReadHeaders({{kClientId, clientId}});
auto controller = selector.select("myThriftMethod", &header);
admControllerSet.insert(controller);
}
}
DummyRequest requestC;
DummyConnContext connContextC;
auto controllerForEmpty =
selector.select("myThriftMethod", requestC, connContextC);
connContextC.setReadHeader("client_id", "C");
auto controllerForC =
selector.select("myThriftMethod", requestC, connContextC);

THeader header;
auto controllerForEmpty = selector.select("myThriftMethod", &header);

THeader headerC;
headerC.setReadHeaders({{kClientId, "C"}});
auto controllerForC = selector.select("myThriftMethod", &headerC);

ASSERT_FALSE(controllerForC->admit());
ASSERT_EQ(controllerForEmpty, controllerForC);

Expand All @@ -243,8 +179,6 @@ TEST_F(AdmissionControllerSelectorTest, deniesZeroPriority) {
ASSERT_EQ(mapping["A"].size(), 2);
ASSERT_EQ(mapping["B"].size(), 1);
auto admControllerForB = *mapping["B"].begin();
DummyRequest requestB;
DummyConnContext connContextB;
ASSERT_FALSE(admControllerForB->admit());
}

Expand All @@ -253,15 +187,11 @@ TEST_F(AdmissionControllerSelectorTest, whiteListAdmission) {
WhitelistAdmissionStrategy<GlobalAdmissionStrategy> selector(
whitelist, std::make_shared<DummyController>());

DummyRequest request;
DummyConnContext connContext;
auto admissionController =
selector.select("myThriftMethod", request, connContext);
THeader header;
auto admissionController = selector.select("myThriftMethod", &header);
ASSERT_NE(dynamic_cast<DummyController*>(admissionController.get()), nullptr);

DummyConnContext connContext2;
auto admissionController2 =
selector.select("getStatus", request, connContext2);
auto admissionController2 = selector.select("getStatus", &header);
ASSERT_NE(
dynamic_cast<AcceptAllAdmissionController*>(admissionController2.get()),
nullptr);
Expand Down

0 comments on commit 9443939

Please sign in to comment.