/*
   Mathieu Stefani, 07 February 2016

   Example of a REST endpoint with routing.
   Modified to use Client to forward a request.
*/

#include "http.h"
#include "router.h"
#include "endpoint.h"
#include "client.h"

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <exception>
#include <iostream>
#include <limits>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

using namespace std;
using namespace Net;

/// Prints every cookie carried by @p req to stdout as "name = value",
/// one per line, wrapped in "Cookies: [" ... "]".
void printCookies(const Net::Http::Request& req) {
    auto cookies = req.cookies();
    std::cout << "Cookies: [" << '\n';
    const std::string indent(4, ' ');
    for (const auto& c: cookies) {
        std::cout << indent << c.name << " = " << c.value << '\n';
    }
    // Flush once at the end instead of after every line.
    std::cout << "]" << std::endl;
}

namespace Generic {

/// Readiness probe: always replies 200 OK with the body "1".
void handleReady(const Rest::Request&, Http::ResponseWriter response) {
    response.send(Http::Code::Ok, "1");
}

} // namespace Generic

/// REST service keeping named integer counters ("metrics") in memory.
/// It can also forward a request to another HTTP server (see doForward).
class StatsEndpoint {
public:
    explicit StatsEndpoint(Net::Address addr)
        : httpEndpoint(std::make_shared<Net::Http::Endpoint>(addr))
    { }

    /// Configures the HTTP endpoint with @p thr threads and the
    /// InstallSignalHandler flag. It then registers all routes on the
    /// router that start() installs as the request handler.
    void init(size_t thr = 2) {
        auto opts = Net::Http::Endpoint::options()
            .threads(thr)
            .flags(Net::Tcp::Options::InstallSignalHandler);
        httpEndpoint->init(opts);
        setupRoutes();
    }

    /// Installs the router as the endpoint's handler and starts serving.
    /// NOTE(review): presumably blocks until the server stops, since main()
    /// calls shutdown() right after; TODO confirm.
    void start() {
        httpEndpoint->setHandler(router.handler());
        httpEndpoint->serve();
    }

    void shutdown() {
        httpEndpoint->shutdown();
    }

    /// Sets the URL that GET /forward relays requests to.
    /// While it is empty, /forward is rejected.
    void setForwardAddress(const std::string& address) {
        forwardAddress = address;
    }

private:
    /// A named integer counter.
    class Metric {
    public:
        explicit Metric(std::string name, int initialValue = 1)
            : name_(std::move(name))
            , value_(initialValue)
        { }

        /// Adds @p n to the counter and returns the value *before* the increment.
        int incr(int n = 1) {
            int old = value_;
            value_ += n;
            return old;
        }

        int value() const { return value_; }

        // Returned by reference: it is compared on every lookup, so avoid a copy.
        const std::string& name() const { return name_; }

    private:
        std::string name_;
        int value_;
    };

    typedef std::mutex Lock;
    typedef std::lock_guard<Lock> Guard;

    /// Registers every route this service answers.
    void setupRoutes() {
        using namespace Net::Rest;

        Routes::Post(router, "/record/:name/:value?", Routes::bind(&StatsEndpoint::doRecordMetric, this));
        Routes::Get(router, "/value/:name", Routes::bind(&StatsEndpoint::doGetMetric, this));
        Routes::Get(router, "/ready", Routes::bind(&Generic::handleReady));
        Routes::Get(router, "/auth", Routes::bind(&StatsEndpoint::doAuth, this));
        Routes::Get(router, "/forward", Routes::bind(&StatsEndpoint::doForward, this));
    }

    /// Returns an iterator to the metric called @p name, or metrics.end().
    /// The caller must hold metricsLock.
    std::vector<Metric>::iterator findMetric(const std::string& name) {
        return std::find_if(metrics.begin(), metrics.end(), [&](const Metric& metric) {
            return metric.name() == name;
        });
    }

    /// POST /record/:name/:value?
    /// Adds :value (default 1) to metric :name.
    /// - New metric: it is created and the reply is 201 Created with its initial value.
    /// - Existing metric: it is incremented and the reply is 200 OK with the new total.
    void doRecordMetric(const Rest::Request& request, Net::Http::ResponseWriter response) {
        auto name = request.param(":name").as<std::string>();

        // Parse the optional increment before taking the lock.
        int val = 1;
        if (request.hasParam(":value")) {
            auto value = request.param(":value");
            val = value.as<int>();
        }

        Guard guard(metricsLock);
        auto it = findMetric(name);
        if (it == metrics.end()) {
            metrics.emplace_back(std::move(name), val);
            response.send(Http::Code::Created, std::to_string(val));
        } else {
            auto& metric = *it;
            metric.incr(val);
            response.send(Http::Code::Ok, std::to_string(metric.value()));
        }
    }

    /// GET /value/:name
    /// Replies 200 OK with the metric's current value.
    /// Replies 404 Not Found if no metric has that name.
    void doGetMetric(const Rest::Request& request, Net::Http::ResponseWriter response) {
        auto name = request.param(":name").as<std::string>();

        Guard guard(metricsLock);
        auto it = findMetric(name);
        if (it == metrics.end()) {
            response.send(Http::Code::Not_Found, "Metric does not exist");
        } else {
            const auto& metric = *it;
            response.send(Http::Code::Ok, std::to_string(metric.value()));
        }
    }

    /// GET /auth
    /// Logs the incoming cookies, sets cookie lang=en-US and replies 200 OK.
    void doAuth(const Rest::Request& request, Net::Http::ResponseWriter response) {
        printCookies(request);
        response.cookies()
            .add(Http::Cookie("lang", "en-US"));
        response.send(Http::Code::Ok);
    }

    /// GET /forward
    /// Sends a GET carrying this request's body to forwardAddress, using a
    /// one-off HTTP client. Relays the upstream status code and body.
    /// If the success callback has not run by the time the client is shut
    /// down, replies 504 Gateway Timeout with "Forwarding request failed\n".
    /// That covers an upstream error or no reply within the 5-second wait.
    void doForward(const Rest::Request& request, Net::Http::ResponseWriter response) {
        if (forwardAddress.empty()) {
            // NOTE(review): 405 is an odd fit for "not configured"; 503 may be clearer.
            response.send(Http::Code::Method_Not_Allowed, "A forward address must be specified to handle forward calls");
            return;
        }

        Http::Client client;
        auto clientOpts = Http::Client::options()
            .threads(1)
            .maxConnectionsPerHost(2);
        client.init(clientOpts);

        std::vector<Async::Promise<Http::Response>> responsePromises;
        responsePromises.push_back(client.get(forwardAddress).body(request.body()).send());

        // Reply used unless the success callback overwrites it.
        Http::Code responseCode = Http::Code::Gateway_Timeout;
        std::string responseBody("Forwarding request failed\n");

        // The callbacks capture locals by reference. This relies on
        // client.shutdown() (below, before return) stopping the client
        // threads so no callback can fire afterwards; TODO confirm.
        responsePromises[0].then(
            [&](Http::Response clientResponse) {
                responseCode = clientResponse.code();
                responseBody = clientResponse.body();
            },
            [&](std::exception_ptr exception) {
                // This is never invoked, see GitHub issue #37.
                try {
                    std::rethrow_exception(exception);
                } catch (const std::exception& exc) {
                    std::cerr << "Failed to forward request: " << exc.what() << std::endl;
                } catch (...) {
                    // Never let a non-std exception escape the rejection callback.
                    std::cerr << "Failed to forward request: unknown error" << std::endl;
                }
            }
        );

        // Wait at most 5 seconds for the client response.
        auto sync = Async::whenAll(responsePromises.begin(), responsePromises.end());
        Async::Barrier<std::vector<Http::Response>> barrier(sync);
        barrier.wait_for(std::chrono::seconds(5));

        client.shutdown();
        response.send(responseCode, responseBody);
    }

    Lock metricsLock;               // guards `metrics`
    std::vector<Metric> metrics;

    std::shared_ptr<Net::Http::Endpoint> httpEndpoint;
    Rest::Router router;
    std::string forwardAddress;     // empty => /forward is rejected
};

/// Usage: <program> [port [threads [forward-address]]]
/// Defaults are port 9080 and 2 threads.
/// Returns 1 if port or threads cannot be parsed or are out of range.
int main(int argc, char *argv[]) {
    Net::Port port(9080);

    int thr = 2;

    if (argc >= 2) {
        try {
            const long parsedPort = std::stol(argv[1]);
            if (parsedPort < 1 || parsedPort > 65535) {
                std::cerr << "Invalid port (expected 1-65535): " << argv[1] << std::endl;
                return 1;
            }
            port = static_cast<uint16_t>(parsedPort);

            if (argc >= 3) {
                const long parsedThreads = std::stol(argv[2]);
                if (parsedThreads < 1 || parsedThreads > std::numeric_limits<int>::max()) {
                    std::cerr << "Invalid thread count: " << argv[2] << std::endl;
                    return 1;
                }
                thr = static_cast<int>(parsedThreads);
            }
        } catch (const std::exception& exc) {
            // std::stol throws invalid_argument / out_of_range on bad input.
            std::cerr << "Invalid command-line argument: " << exc.what() << std::endl;
            return 1;
        }
    }

    Net::Address addr(Net::Ipv4::any(), port);

    cout << "Cores = " << hardware_concurrency() << endl;
    cout << "Using " << thr << " threads" << endl;

    StatsEndpoint stats(addr);
    if (argc >= 4)
        stats.setForwardAddress(argv[3]);

    stats.init(thr);
    stats.start();

    stats.shutdown();
}