Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

start balancer_raw

  • Loading branch information...
commit c76fd69762622d25d0fe26c4b367b9d5d6b0469c 1 parent 84b682a
@chenshuo authored
View
7 examples/protobuf/rpcbalancer/CMakeLists.txt
@@ -1,4 +1,9 @@
+include_directories(${PROJECT_BINARY_DIR})
+
add_executable(protobuf_rpc_balancer balancer.cc)
set_target_properties(protobuf_rpc_balancer PROPERTIES COMPILE_FLAGS "-Wno-error=shadow")
target_link_libraries(protobuf_rpc_balancer muduo_protorpc)
-include_directories(${PROJECT_BINARY_DIR})
+
+add_executable(protobuf_rpc_balancer_raw balancer_raw.cc)
+set_target_properties(protobuf_rpc_balancer_raw PROPERTIES COMPILE_FLAGS "-Wno-error=shadow")
+target_link_libraries(protobuf_rpc_balancer_raw muduo_protorpc)
View
274 examples/protobuf/rpcbalancer/balancer_raw.cc
@@ -0,0 +1,274 @@
+#include <muduo/base/Logging.h>
+#include <muduo/base/ThreadLocal.h>
+#include <muduo/net/EventLoop.h>
+#include <muduo/net/EventLoopThreadPool.h>
+#include <muduo/net/TcpClient.h>
+#include <muduo/net/TcpServer.h>
+#include <muduo/net/protorpc/RpcCodec.h>
+#include <muduo/net/protorpc/rpc.pb.h>
+
+#include <boost/bind.hpp>
+#include <boost/ptr_container/ptr_vector.hpp>
+
+#include <stdio.h>
+
+using namespace muduo;
+using namespace muduo::net;
+
+class BackendSession : boost::noncopyable
+{
+ public:
+ BackendSession(EventLoop* loop, const InetAddress& backendAddr, const string& name)
+ : loop_(loop),
+ client_(loop, backendAddr, name),
+ codec_(boost::bind(&BackendSession::onRpcMessage, this, _1, _2, _3)),
+ nextId_(0)
+ {
+ client_.setConnectionCallback(
+ boost::bind(&BackendSession::onConnection, this, _1));
+ client_.setMessageCallback(
+ boost::bind(&RpcCodec::onMessage, &codec_, _1, _2, _3));
+ client_.enableRetry();
+ }
+
+ void connect()
+ {
+ client_.connect();
+ }
+
+ // FIXME: add health check
+ bool send(RpcMessage& msg, const TcpConnectionPtr& clientConn)
+ {
+ loop_->assertInLoopThread();
+ if (conn_)
+ {
+ uint64_t id = ++nextId_;
+ Request r = { msg.id(), clientConn };
+ assert(outstandings_.find(id) == outstandings_.end());
+ outstandings_[id] = r;
+ msg.set_id(id);
+ codec_.send(conn_, msg);
+ // LOG_DEBUG << "forward " << r.origId << " from " << clientConn->name()
+ // << " as " << id << " to " << conn_->name();
+ return true;
+ }
+ else
+ return false;
+ }
+
+ private:
+ void onConnection(const TcpConnectionPtr& conn)
+ {
+ loop_->assertInLoopThread();
+ LOG_INFO << "Backend "
+ << conn->localAddress().toIpPort() << " -> "
+ << conn->peerAddress().toIpPort() << " is "
+ << (conn->connected() ? "UP" : "DOWN");
+ if (conn->connected())
+ {
+ conn_ = conn;
+ }
+ else
+ {
+ conn_.reset();
+ // FIXME: reject pending
+ }
+ }
+
+ void onRpcMessage(const TcpConnectionPtr&,
+ const RpcMessagePtr& msg,
+ Timestamp)
+ {
+ loop_->assertInLoopThread();
+ std::map<uint64_t, Request>::iterator it = outstandings_.find(msg->id());
+ if (it != outstandings_.end())
+ {
+ uint64_t origId = it->second.origId;
+ TcpConnectionPtr clientConn = it->second.clientConn.lock();
+ outstandings_.erase(it);
+
+ if (clientConn)
+ {
+ // LOG_DEBUG << "send back " << origId << " of " << clientConn->name()
+ // << " using " << msg.id() << " from " << conn_->name();
+ msg->set_id(origId);
+ codec_.send(clientConn, *msg);
+ }
+ }
+ else
+ {
+ // LOG_ERROR
+ }
+ }
+
+ struct Request
+ {
+ uint64_t origId;
+ boost::weak_ptr<TcpConnection> clientConn;
+ };
+
+ EventLoop* loop_;
+ TcpClient client_;
+ RpcCodec codec_;
+ TcpConnectionPtr conn_;
+ uint64_t nextId_;
+ std::map<uint64_t, Request> outstandings_;
+};
+
+class Balancer : boost::noncopyable
+{
+ public:
+ Balancer(EventLoop* loop,
+ const InetAddress& listenAddr,
+ const string& name,
+ const std::vector<InetAddress>& backends)
+ : loop_(loop),
+ server_(loop, listenAddr, name),
+ codec_(RpcCodec::ProtobufMessageCallback(),
+ boost::bind(&Balancer::onRawMessage, this, _1, _2, _3, _4)),
+ backends_(backends)
+ {
+ server_.setThreadInitCallback(
+ boost::bind(&Balancer::initPerThread, this, _1));
+ server_.setConnectionCallback(
+ boost::bind(&Balancer::onConnection, this, _1));
+ server_.setMessageCallback(
+ boost::bind(&RpcCodec::onMessage, &codec_, _1, _2, _3));
+ }
+
+ ~Balancer()
+ {
+ }
+
+ void setThreadNum(int numThreads)
+ {
+ server_.setThreadNum(numThreads);
+ }
+
+ void start()
+ {
+ server_.start();
+ }
+
+ private:
+ struct PerThread
+ {
+ size_t current;
+ boost::ptr_vector<BackendSession> backends;
+ PerThread() : current(0) { }
+ };
+
+ void initPerThread(EventLoop* ioLoop)
+ {
+ int count = threadCount_.getAndAdd(1);
+ LOG_INFO << "IO thread " << count;
+ PerThread& t = t_backends_.value();
+ t.current = count % backends_.size();
+
+ for (size_t i = 0; i < backends_.size(); ++i)
+ {
+ char buf[32];
+ snprintf(buf, sizeof buf, "%s#%d", backends_[i].toIpPort().c_str(), count);
+ t.backends.push_back(new BackendSession(ioLoop, backends_[i], buf));
+ t.backends.back().connect();
+ }
+ }
+
+ void onConnection(const TcpConnectionPtr& conn)
+ {
+ LOG_INFO << "Client "
+ << conn->peerAddress().toIpPort() << " -> "
+ << conn->localAddress().toIpPort() << " is "
+ << (conn->connected() ? "UP" : "DOWN");
+ if (!conn->connected())
+ {
+ // FIXME: cancel outstanding calls, otherwise, memory leak
+ }
+ }
+
+ bool onRawMessage(const TcpConnectionPtr& conn,
+ const char* buf,
+ int len,
+ Timestamp)
+ {
+ if (ProtobufCodecLite::validateChecksum(buf, len)
+ && (memcmp(buf, codec_.tag().data(), codec_.tag().size()) == 0))
+ {
+ LOG_INFO << "Got raw message";
+ // FIXME:
+ }
+ else
+ {
+ // FIXME:
+ }
+ return false;
+ }
+
+ void onRpcMessage(const TcpConnectionPtr& conn,
+ const RpcMessagePtr& msg,
+ Timestamp)
+ {
+ PerThread& t = t_backends_.value();
+ bool succeed = false;
+ for (size_t i = 0; i < t.backends.size(); ++i)
+ {
+ succeed = t.backends[t.current++].send(*msg, conn);
+ if (succeed)
+ {
+ break;
+ }
+ t.current = t.current % t.backends.size();
+ }
+ t.current = t.current % t.backends.size();
+ if (!succeed)
+ {
+ // FIXME: no backend available
+ }
+ }
+
+ EventLoop* loop_;
+ TcpServer server_;
+ RpcCodec codec_;
+ std::vector<InetAddress> backends_;
+ AtomicInt32 threadCount_;
+ ThreadLocal<PerThread> t_backends_;
+};
+
+int main(int argc, char* argv[])
+{
+ LOG_INFO << "pid = " << getpid();
+ if (argc < 3)
+ {
+ fprintf(stderr, "Usage: %s listen_port backend_ip:port [backend_ip:port]\n", argv[0]);
+ }
+ else
+ {
+ std::vector<InetAddress> backends;
+ for (int i = 2; i < argc; ++i)
+ {
+ string hostport = argv[i];
+ size_t colon = hostport.find(':');
+ if (colon != string::npos)
+ {
+ string ip = hostport.substr(0, colon);
+ uint16_t port = static_cast<uint16_t>(atoi(hostport.c_str()+colon+1));
+ backends.push_back(InetAddress(ip, port));
+ }
+ else
+ {
+ fprintf(stderr, "invalid backend address %s\n", argv[i]);
+ return 1;
+ }
+ }
+ uint16_t port = static_cast<uint16_t>(atoi(argv[1]));
+ InetAddress listenAddr(port);
+
+ EventLoop loop;
+ Balancer balancer(&loop, listenAddr, "RpcBalancer", backends);
+ balancer.setThreadNum(4);
+ balancer.start();
+ loop.loop();
+ }
+ google::protobuf::ShutdownProtobufLibrary();
+}
+
View
27 muduo/net/protobuf/ProtobufCodecLite.cc
@@ -81,18 +81,19 @@ void ProtobufCodecLite::onMessage(const TcpConnectionPtr& conn,
errorCallback_(conn, buf, receiveTime, kInvalidLength);
break;
}
- else if (buf->readableBytes() >= implicit_cast<size_t>(len+kChecksumLen))
+ else if (buf->readableBytes() >= implicit_cast<size_t>(len+kHeaderLen))
{
+ if (rawCb_ && !rawCb_(conn, buf->peek()+kHeaderLen, len, receiveTime))
+ {
+ buf->retrieve(kHeaderLen+len);
+ continue;
+ }
MessagePtr message(prototype_->New());
// FIXME: can we move deserialization & callback to other thread?
ErrorCode errorCode = parse(buf->peek()+kHeaderLen, len, message.get());
if (errorCode == kNoError)
{
// FIXME: try { } catch (...) { }
- if (rawCb_)
- {
- rawCb_(conn, buf->peek()+kHeaderLen, len, receiveTime);
- }
messageCallback_(conn, message, receiveTime);
buf->retrieve(kHeaderLen+len);
}
@@ -160,20 +161,24 @@ int32_t ProtobufCodecLite::asInt32(const char* buf)
return sockets::networkToHost32(be32);
}
-ProtobufCodecLite::ErrorCode ProtobufCodecLite::parse(const char* buf,
- int len,
- ::google::protobuf::Message* message)
+bool ProtobufCodecLite::validateChecksum(const char* buf, int len)
{
- ErrorCode error = kNoError;
-
// check sum
int32_t expectedCheckSum = asInt32(buf + len - kChecksumLen);
int32_t checkSum = static_cast<int32_t>(
::adler32(1,
reinterpret_cast<const Bytef*>(buf),
static_cast<int>(len - kChecksumLen)));
+ return checkSum == expectedCheckSum;
+}
+
+ProtobufCodecLite::ErrorCode ProtobufCodecLite::parse(const char* buf,
+ int len,
+ ::google::protobuf::Message* message)
+{
+ ErrorCode error = kNoError;
- if (checkSum == expectedCheckSum)
+ if (validateChecksum(buf, len))
{
if (memcmp(buf, tag_.data(), tag_.size()) == 0)
{
View
20 muduo/net/protobuf/ProtobufCodecLite.h
@@ -66,9 +66,10 @@ class ProtobufCodecLite : boost::noncopyable
kParseError,
};
- typedef boost::function<void (const TcpConnectionPtr&,
+ // return false to stop parsing protobuf message
+ typedef boost::function<bool (const TcpConnectionPtr&,
const char*,
- size_t,
+ int,
Timestamp)> RawMessageCallback;
typedef boost::function<void (const TcpConnectionPtr&,
@@ -81,19 +82,21 @@ class ProtobufCodecLite : boost::noncopyable
ErrorCode)> ErrorCallback;
ProtobufCodecLite(const ::google::protobuf::Message* prototype,
- StringPiece tag,
+ StringPiece tagArg,
const ProtobufMessageCallback& messageCb,
const RawMessageCallback& rawCb = RawMessageCallback(),
const ErrorCallback& errorCb = defaultErrorCallback)
: prototype_(prototype),
- tag_(tag.as_string()),
+ tag_(tagArg.as_string()),
messageCallback_(messageCb),
rawCb_(rawCb),
errorCallback_(errorCb),
- kMinMessageLen(tag.size() + kChecksumLen)
+ kMinMessageLen(tagArg.size() + kChecksumLen)
{
}
+ const string& tag() const { return tag_; }
+
void send(const TcpConnectionPtr& conn,
const ::google::protobuf::Message& message);
@@ -107,6 +110,7 @@ class ProtobufCodecLite : boost::noncopyable
ErrorCode parse(const char* buf, int len, ::google::protobuf::Message* message);
void fillEmptyBuffer(muduo::net::Buffer* buf, const google::protobuf::Message& message);
+ static bool validateChecksum(const char* buf, int len);
static int32_t asInt32(const char* buf);
static void defaultErrorCallback(const TcpConnectionPtr&,
Buffer*,
@@ -134,19 +138,23 @@ class ProtobufCodecLiteT
typedef boost::function<void (const TcpConnectionPtr&,
const ConcreteMessagePtr&,
Timestamp)> ProtobufMessageCallback;
+ typedef ProtobufCodecLite::RawMessageCallback RawMessageCallback;
typedef ProtobufCodecLite::ErrorCallback ErrorCallback;
explicit ProtobufCodecLiteT(const ProtobufMessageCallback& messageCb,
+ const RawMessageCallback& rawCb = RawMessageCallback(),
const ErrorCallback& errorCb = ProtobufCodecLite::defaultErrorCallback)
: messageCallback_(messageCb),
codec_(&MSG::default_instance(),
TAG,
boost::bind(&ProtobufCodecLiteT::onRpcMessage, this, _1, _2, _3),
- ProtobufCodecLite::RawMessageCallback(),
+ rawCb,
errorCb)
{
}
+ const string& tag() const { return codec_.tag(); }
+
void send(const TcpConnectionPtr& conn,
const MSG& message)
{

0 comments on commit c76fd69

Please sign in to comment.
Something went wrong with that request. Please try again.