Skip to content
Browse files

add rpc and cliet dir

  • Loading branch information...
1 parent a581315 commit 5db804ea18bec15d96febef4aa3fecebfce9fe32 @notedit notedit committed
Showing with 1,055 additions and 0 deletions.
  1. +214 −0 client/client.py
  2. +121 −0 rpc/all_test.go
  3. +234 −0 rpc/client.go
  4. +486 −0 rpc/server.go
View
214 client/client.py
@@ -0,0 +1,214 @@
+# -*- coding: utf-8 -*-
+# date: 2012-02-20
+# author: notedit<notedit@gmail.com>
+
+import socket
+import struct
+import bson
+from bson import BSON
+from Queue import Queue
+
+"""
+a simple connection pool
+"""
+
+SOCKET_TIMEOUT = 10.0
+
+
+
+class BackendError(Exception):
+
+ def __init__(self,message,detail):
+ self.message = message
+ self.detail = detail
+
+ def __str__(self):
+ return 'BackendError(%s,%s)' % (self.message,self.detail)
+
+ def __repr__(self):
+ return 'BakcneError(%s,%s)' % (self.message,self.detail)
+
+
+class RpcError(Exception):
+
+ def __init__(self,message,detail):
+ self.message = message
+ self.detail = detail
+
+ def __str__(self):
+ return '%s:%s' % (self.message,self.detail)
+
+ def __repr__(self):
+ return '%s,%s' % (self.message,self.detail)
+
+class Connection(object):
+
+ def __init__(self,host='localhost',port=9090):
+ self.host = host
+ self.port = port
+ self._sock = None
+
+ def __del__(self):
+ try:
+ self.disconnect()
+ except:
+ pass
+
+ def connect(self):
+ if self._sock:
+ return
+ try:
+ sock = self._connect()
+ except socket.timeout,e:
+ raise RpcError('ConnTimeoutError',str(e))
+ except socket.error,e:
+ raise RpcError('ConntionError',str(e))
+ self._sock = sock
+
+ def _connect(self):
+ sock = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
+ sock.settimeout(SOCKET_TIMEOUT)
+ sock.connect((self.host,self.port))
+ return sock
+
+ def disconnect(self):
+ if self._sock is None:
+ return
+ try:
+ self._sock.close()
+ except socket.error:
+ pass
+ self._sock = None
+
+ def write_request(self,method,args):
+ if not self._sock:
+ self.connect()
+ try:
+ data = self.encode(method,args)
+ self._sock.sendall(data)
+ except:
+ self.disconnect()
+ raise
+
+ def read_request(self):
+ try:
+ m = self._sock.recv(4)
+ while len(m) < 4:
+ m += self._sock.recv(4-len(m))
+ lt = struct.unpack('>i',m)[0]
+ ret = self._sock.recv(lt)
+ while len(ret) < lt:
+ ret += self._sock.recv(lt - len(ret))
+ return self.decode(ret)
+ except:
+ self.disconnect()
+ raise
+
+ def decode(self,data):
+ ret = bson.BSON(data).decode()
+ return ret
+
+ def encode(self,method,args):
+ cdict = {'operation':1,'method':method,'argument':args}
+ cs = bson.BSON.encode(cdict)
+ msghead = struct.pack('>i',len(cs))
+ return msghead + cs
+
+ def call(self,method,args):
+ try:
+ self.write_request(method,args)
+ ret = self.read_request()
+ if ret.get('operation',None) == 2:
+ return ret['reply'] or None
+ elif ret.get('operation',None) == 3:
+ raise BackendError(ret['reply']['message'],ret['reply']['detail'])
+ else:
+ raise BackendError('InternalError','unvalid response')
+ except BackendError,err:
+ raise
+ except Exception,err:
+ raise BackendError('InternalError',str(err))
+
+
+class ConnectionPool(object):
+ """Generic connection pool"""
+ def __init__(self,host='localhost',port=9090,max_connection=10):
+ self.host = host
+ self.port = port
+ self.max_connection = max_connection
+ self._connections = self.initconns()
+
+ def initconns(self):
+ conns = Queue(self.max_connection)
+ #for x in xrange(self.max_connection):
+ conns.put(None)
+ return conns
+
+ def get_connection(self):
+ cn = self._connections.get(True,1) # set the timeout 1 second
+ if cn is None:
+ cn = Connection(self.host,self.port)
+ return cn
+
+ def release(self,cn):
+ if not self._connections.full():
+ self._connections.put(cn)
+ else:
+ cn.disconnect()
+
+class Service(object):
+ """rpc service"""
+ def __init__(self,sname,cliet):
+ self.sname = sname
+ self.client = client
+ self.mdict = {} # a dict
+
+ def __getattr__(self,method):
+ sm = '%s.%s'% (self.sname,method)
+ if self.mdict.get(sm,None) is None:
+ me = lambda args: self.__call__(sm,args)
+ me.__name__ = sm
+ self.mdict[sm] = me
+ return me
+ else:
+ return self.mdict.get(sm)
+
+ def __call__(self,method,args):
+ #serviceMethod
+ conn = self.client.pool.get_connection()
+ try:
+ ret = conn.call(method,args)
+ except BackendError,ex:
+ self.client.pool.release(conn)
+ raise
+ self.client.pool.release(conn)
+ return ret
+
+
+class RpcClient(object):
+ """rpc client"""
+ sdict = {}
+ pool = None
+ def __init__(self,host='localhost',port=9090):
+ self.host = host
+ self.port = port
+ self.pool = ConnectionPool(host,port)
+
+ def __getattr__(self,service):
+ if self.sdict.get(service,None) is None:
+ ser = Service(service,self)
+ self.sdict[service] = ser
+ return ser
+ else:
+ return self.sdict.get(service)
+
+
+
+
+if __name__ == '__main__':
+ client = RpcClient(host='localhost',port=9090)
+ ret = client.Arith.Add({'a':7,'b':8})
+ print 'Arith.Add',ret
+
+ ret = client.Arith.Mul({'a':7,'b':8})
+ print 'Arith.Mul',ret
View
121 rpc/all_test.go
@@ -0,0 +1,121 @@
+// Date: 2012-02-16
+// Author: notedit<notedit@gmail.com>
+
+package rpc
+
+import (
+ "fmt"
+ "errors"
+ "testing"
+)
+
+
+type Args struct {
+ A,B int
+}
+
+type Reply struct {
+ C int
+}
+
+type Arith int
+
+func (t *Arith) Add(args *Args,reply *Reply) error {
+ reply.C = args.A + args.B
+ return nil
+}
+
+func (t *Arith) Mul(args *Args,reply *Reply) error {
+ reply.C = args.A * args.B
+ return nil
+}
+
+func (t *Arith) Div(args *Args,reply *Reply) error {
+ if args.B == 0 {
+ return BackendError{"InternalError","divide by zero"}
+ }
+ reply.C = args.A / args.B
+ return nil
+}
+
+func (t *Arith) NError(args *Args,reply *Reply) error {
+ return errors.New("normalerror")
+}
+
+func (t *Arith) Error(args *Args,reply *Reply) error {
+ panic("ERROR")
+}
+
+func startServer() {
+ newServer := NewServer("localhost",9091)
+ newServer.Register(new(Arith))
+ newServer.Serv()
+}
+
+func TestServer(t *testing.T) {
+ go startServer()
+ client := New("localhost:9091")
+
+ // normal calls
+ args := &Args{7,8}
+ reply := new(Reply)
+ err := client.Call("Arith.Add",args,reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q",err.Error())
+ }
+ if reply.C != args.A + args.B {
+ t.Errorf("Add: expected %d got %d",reply.C,args.A + args.B)
+ }
+
+ // Nonexistent method
+ args = &Args{7,0}
+ reply = new(Reply)
+ err = client.Call("Arith.BadOperation",args,reply)
+ if err == nil {
+ t.Error("BadOperation: expected errpor")
+ } else if err.(BackendError).Message != "InternalError" {
+ fmt.Printf("%#v\n",err)
+ t.Errorf("BadOperation: expected can't find method error")
+ }
+
+ // normal error
+
+ err = client.Call("Arith.NError",args,reply)
+ if err == nil {
+ t.Error("expected normal error")
+ } else if err.(BackendError).Detail != "normalerror" {
+ fmt.Println(err)
+ t.Errorf("error detail will be normalerror, %v\n",err)
+ }
+
+ // Unknown service
+ args = &Args{7,8}
+ reply = new(Reply)
+ err = client.Call("Unknow.Arith",args,reply)
+ if err == nil {
+ t.Error("expected Unknow service error")
+ } else if err.(BackendError).Message != "InternalError" {
+ t.Error("error message will be InternalError ")
+ }
+
+
+ // Error test
+ args = &Args{7,0}
+ reply = new(Reply)
+ err = client.Call("Arith.Div",args,reply)
+
+ fmt.Printf("%#v\n",err)
+ if err == nil {
+ t.Error("Div: expected error")
+ } else if err.(BackendError).Detail != "divide by zero" {
+ t.Error("expected divide by zero error detail")
+ }
+
+ // Panic test
+ err = client.Call("Arith.Error",args,reply)
+ if err == nil {
+ t.Error("expect panic error")
+ } else if err.(BackendError).Detail != "ERROR" {
+ t.Error("Panic test expect ERROR detail")
+ }
+}
View
234 rpc/client.go
@@ -0,0 +1,234 @@
+// Date: 2012-02-14
+// Author: notedit<notedit@gmail.com>
+
+package rpc
+
+import (
+ "log"
+ "io"
+ "net"
+ "io/ioutil"
+ "bufio"
+ "sync"
+ "time"
+ "errors"
+ "encoding/binary"
+ "launchpad.net/mgo/bson"
+)
+
+// timeout
+const DefaultTimeout = time.Duration(10000) * time.Millisecond
+// connection pool number
+const DefaultConnectionPool = 10
+
+
+
+type Client struct {
+ addr net.Addr
+ seq uint32
+ mutex sync.Mutex
+ Timeout time.Duration
+ freeconn []*conn
+}
+
+type conn struct {
+ cn net.Conn
+ rw *bufio.ReadWriter
+ c *Client
+}
+
+func (cn *conn)WriteRequest(req *clientRequest)(err error) {
+ bys,err := bson.Marshal(req)
+ if err != nil {
+ log.Println(err)
+ return
+ }
+ req.messageLength = uint32(len(bys))
+ // write message header
+ rw := cn.rw.Writer
+ _,err = rw.Write([]byte{byte(req.messageLength>>24),byte(req.messageLength>>16),byte(req.messageLength>>8),byte(req.messageLength)})
+ if err != nil {
+ log.Println("write requestHeader error:",err)
+ return
+ }
+ _,err = rw.Write(bys)
+ if err != nil {
+ return
+ }
+ if err = rw.Flush(); err != nil {
+ log.Println("write requestBody error:",err)
+ }
+ return
+}
+
+func (cn *conn)ReadResponse(res *clientResponse) (err error){
+ if err = cn.ReadResponseHeader(res); err != nil {
+ return
+ }
+ err = cn.ReadResponseBody(res)
+ return
+}
+
+func (cn *conn)ReadResponseHeader(res *clientResponse) (err error) {
+ msgheader := make([]byte,4)
+ _,err = cn.rw.Read(msgheader)
+ if err != nil {
+ res = nil
+ if err == io.EOF {
+ return errors.New("rpc: client cannot read requestHeader" + err.Error())
+ }
+ err = errors.New("rpc: client cannot read requestHeader " + err.Error())
+ return
+ }
+ res.messageLength = binary.BigEndian.Uint32(msgheader)
+ return nil
+}
+
+func (cn *conn)ReadResponseBody(res *clientResponse) (err error) {
+ msgbody,err := ioutil.ReadAll(io.LimitReader(cn.rw.Reader,int64(res.messageLength)))
+ if err != nil {
+ err = errors.New("rpc: client cannot read requestBody "+ err.Error())
+ return
+ }
+ if err = bson.Unmarshal(msgbody,res); err != nil {
+ return
+ }
+ return
+}
+
+type clientRequest struct {
+ messageLength uint32 // unexported element will not be marshaled
+ Operation uint8
+ Method string
+ Argument interface{}
+}
+
+type clientResponse struct {
+ messageLength uint32
+ Operation uint8
+ Reply bson.Raw
+}
+
+func New(server string) *Client {
+ addr,err := net.ResolveTCPAddr("tcp",server)
+ if err != nil {
+ panic(err)
+ }
+ return &Client{addr:addr,freeconn:make([]*conn,0,DefaultConnectionPool)}
+}
+
+func (c *Client)dial()(net.Conn,error) {
+ cn,err := net.Dial(c.addr.Network(),c.addr.String())
+ if err != nil {
+ return nil,err
+ }
+ return cn,nil
+}
+
+func (c *Client)getConn()(*conn,error) {
+ cn,ok := c.getFreeConn()
+ if ok {
+ return cn,nil
+ }
+
+
+ nc,err := c.dial()
+ if err != nil {
+ return nil,err
+ }
+
+ return &conn{
+ cn: nc,
+ rw: bufio.NewReadWriter(bufio.NewReader(nc),bufio.NewWriter(nc)),
+ c: c,
+ }, nil
+}
+
+func (c *Client)getFreeConn()(*conn,bool) {
+ c.mutex.Lock()
+ defer c.mutex.Unlock()
+ if len(c.freeconn) == 0 {
+ return nil,false
+ }
+ cn := c.freeconn[len(c.freeconn)-1]
+ c.freeconn = c.freeconn[:len(c.freeconn)-1]
+ return cn,true
+}
+
+func (c *Client)release(cn *conn) {
+ c.mutex.Lock()
+ defer c.mutex.Unlock()
+ if len(c.freeconn) >= DefaultConnectionPool {
+ cn.cn.Close()
+ return
+ }
+ c.freeconn = append(c.freeconn,cn)
+}
+
+func (c *Client)call(req *clientRequest,reply interface{}) (err error) {
+ cn,err := c.getConn()
+ if err != nil {
+ return
+ }
+ defer func(){
+ if cn != nil {
+ c.release(cn)
+ }
+ switch err.(type) {
+ case BackendError:
+ case error:
+ err = BackendError{Message:"ClientError",Detail:err.Error()}
+ default:
+ }
+ }()
+ if err = cn.WriteRequest(req); err != nil {
+ return err
+ }
+ res := &clientResponse{}
+ if err = cn.ReadResponse(res); err != nil {
+ return
+ }
+ if err = parseReply(res,reply); err != nil {
+ return
+ }
+ return nil
+}
+
+func parseReply(res *clientResponse,reply interface{}) error {
+ if res.Operation == 2 {
+ // valid reply
+ err := res.Reply.Unmarshal(reply)
+ //if err != nil {
+ // e.Message = "UnvalidUnmarshalError"
+ // e.Detail = err.Error()
+ // return e
+ //}
+ return err
+ } else if res.Operation == 3 {
+ // error reply
+ e := &BackendError{}
+ err := res.Reply.Unmarshal(e)
+ if err != nil {
+ e.Message = "ClientUnmarshalError"
+ e.Detail = err.Error()
+ }
+ return *e
+ }
+ return BackendError{
+ Message:"UnvalidOperationError",
+ Detail:"unvalid oparation error, it may be 2 or 3",
+ }
+}
+
+func (c *Client)Call(serviceMethod string,args interface{},reply interface{}) error {
+ req := new(clientRequest)
+ req.Method = serviceMethod
+ req.Argument = args
+ req.Operation = uint8(1)
+ err := c.call(req,reply)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
View
486 rpc/server.go
@@ -0,0 +1,486 @@
+// Date: 2012-02-08
+// Author: notedit <notedit@gmail.com>
+// make a go rpc service
+
+package rpc
+
+import (
+ "fmt"
+ "log"
+ "net"
+ "io"
+ "io/ioutil"
+ "encoding/binary"
+ "bufio"
+ "sync"
+ "reflect"
+ "errors"
+ "strings"
+ "unicode"
+ // "runtime"
+ "unicode/utf8"
+ "launchpad.net/mgo/bson"
+)
+
+var typeOfError = reflect.TypeOf((*error)(nil)).Elem()
+
+type methodType struct {
+ method reflect.Method
+ ArgType reflect.Type
+ ReplyType reflect.Type
+}
+
+type service struct {
+ name string
+ rcvr reflect.Value
+ typ reflect.Type
+ method map[string]*methodType
+}
+
+// rpc server
+type Server struct {
+ mu sync.Mutex
+ serviceMap map[string]*service
+ listener *net.TCPListener
+ reqLock sync.Mutex
+ freeReq *serverRequest
+ respLock sync.Mutex
+ freeResp *serverResponse
+}
+
+
+//BackendError
+type BackendError struct {
+ Message string
+ Detail string
+}
+
+func (e BackendError) Error() string {
+ return fmt.Sprintf("%s:%s",e.Message,e.Detail)
+}
+
+// operation has three values -- call:1 reply:2 error:3
+
+// request
+type serverRequest struct {
+ messageLength uint32 // unexported
+ next *serverRequest // unexported
+ Operation uint8
+ Method string
+ Argument bson.Raw
+}
+
+// response
+type serverResponse struct {
+ messageLength uint32 // unexported
+ next *serverResponse // unexported
+ Operation uint8
+ Reply interface{}
+}
+
+// decode request and encode response
+type ServerCodec struct {
+ cn net.Conn
+ rw *bufio.ReadWriter
+}
+
+// read the message header
+func (c *ServerCodec)ReadRequestHeader(req *serverRequest) (err error) {
+ msgheader := make([]byte,4)
+ _,err = c.rw.Read(msgheader)
+ if err != nil {
+ req = nil
+ if err == io.EOF {
+ return
+ }
+ err = errors.New("rpc: server cannot decode requestheader: " + err.Error())
+ return
+ }
+ req.messageLength = binary.BigEndian.Uint32(msgheader)
+ return nil
+}
+
+func (c *ServerCodec)ReadRequestBody(req *serverRequest) (err error) {
+ msgbytes,err := ioutil.ReadAll(io.LimitReader(c.rw.Reader,int64(req.messageLength)))
+ if err != nil {
+ if err == io.EOF {
+ return
+ }
+ err = errors.New("rpc: server cannot read full requestBody: " + err.Error())
+ return
+ }
+ if err = bson.Unmarshal(msgbytes,req); err != nil {
+ return
+ }
+ return
+}
+
+func (c *ServerCodec)WriteResponse(res *serverResponse) (err error) {
+ bys, err := bson.Marshal(res)
+ if err != nil {
+ log.Println("writeresponse error",err)
+ return
+ }
+ res.messageLength = uint32(len(bys))
+ // write message header
+ rw := c.rw.Writer
+ _,err = rw.Write([]byte{byte(res.messageLength>>24),byte(res.messageLength>>16),byte(res.messageLength>>8),byte(res.messageLength)})
+ if err != nil {
+ log.Println("write responseHeader error",err)
+ return
+ }
+ // write message body
+ _,err = rw.Write(bys)
+ if err != nil {
+ log.Println("write responseBody error",err)
+ return
+ }
+ if err = rw.Flush(); err != nil {
+ log.Println("flush responseBody error",err)
+ }
+ return
+}
+
+// todo
+func (c *ServerCodec)Close() error {
+ return c.cn.Close()
+}
+
+// Is this an exported - upper case
+func isExported(name string) bool {
+ rune,_ := utf8.DecodeRuneInString(name)
+ return unicode.IsUpper(rune)
+}
+
+// Is this typoe exported or a builtin?
+func isExportedOrBuiltinType(t reflect.Type) bool {
+ for t.Kind() == reflect.Ptr {
+ t = t.Elem()
+ }
+
+ return isExported(t.Name()) || t.PkgPath() == ""
+}
+
+// Register a service
+func (server *Server)Register(rcvr interface{}) error {
+ return server.register(rcvr,"",false)
+}
+
+// Register a sevice with a name
+func (server *Server)RegisterName(name string,rcvr interface{}) error {
+ return server.register(rcvr,name,true)
+}
+
+// the real register
+func (server *Server)register(rcvr interface{}, name string,useName bool) error {
+ server.mu.Lock()
+ defer server.mu.Unlock()
+ if server.serviceMap == nil {
+ server.serviceMap = make(map[string]*service)
+ }
+ s := new(service)
+ s.typ = reflect.TypeOf(rcvr)
+ s.rcvr = reflect.ValueOf(rcvr)
+ sname := reflect.Indirect(s.rcvr).Type().Name()
+ if useName {
+ sname = name
+ }
+ if sname == "" {
+ log.Fatal("rpc: no service name for type",s.typ.String())
+ }
+ if !isExported(sname) && !useName {
+ s := "rpc Register: type " + sname + " is not exported"
+ log.Print(s)
+ return errors.New(s)
+ }
+ if _,present := server.serviceMap[sname]; present {
+ return errors.New("rpc: service already defined: " + sname)
+ }
+ s.name = sname
+ s.method = make(map[string]*methodType)
+
+ // Install the methods
+ for m:=0; m < s.typ.NumMethod(); m++ {
+ method := s.typ.Method(m)
+ mtype := method.Type
+ mname := method.Name
+ if method.PkgPath != "" {
+ fmt.Println(method.PkgPath)
+ continue
+ }
+
+ //Method needs three ins
+ if mtype.NumIn() != 3 {
+ log.Println("method needs three ins")
+ continue
+ }
+
+ // Method has one out:error
+ if mtype.NumOut() != 1 {
+ log.Println("method",mname,"has wrong number of outs:",mtype.NumOut())
+ continue
+ }
+
+ // first arg need not be a pointer
+ argType := mtype.In(1)
+ if !isExportedOrBuiltinType(argType) {
+ log.Println(mname,"argument type not exported or local",argType)
+ continue
+ }
+
+ replyType := mtype.In(2)
+ if replyType.Kind() != reflect.Ptr {
+ log.Println("method",mname," reply type not a pointer:",replyType)
+ continue
+ }
+
+ if !isExportedOrBuiltinType(replyType) {
+ log.Println("method ",mname,"reply type not exported or local",replyType)
+ continue
+ }
+
+ // error type
+ if returnType := mtype.Out(0); returnType != typeOfError {
+ log.Println("method",mname," returns",returnType.String(),"not error")
+ continue
+ }
+
+ s.method[mname] = &methodType{method:method,ArgType:argType,ReplyType:replyType}
+ }
+
+ if len(s.method) == 0 {
+ s := "rpc Register: type " + sname + " has no exported methods of suitable type"
+ log.Println(s)
+ return errors.New(s)
+ }
+ server.serviceMap[s.name] = s
+ return nil
+}
+
+
+
+func NewServer(host string,port uint) *Server {
+ addr,err := net.ResolveTCPAddr("tcp",fmt.Sprintf("%s:%d",host,port))
+ if err != nil {
+ log.Fatal("rpc error:",err.Error());
+ }
+ listener,err := net.ListenTCP("tcp",addr)
+ if err != nil {
+ log.Fatal("rpc error:",err.Error())
+ }
+ return &Server{
+ serviceMap:make(map[string]*service),
+ listener:listener,
+ }
+}
+
+
+// request and response pool
+
+func (server *Server) getRequest() *serverRequest {
+ server.reqLock.Lock()
+ req := server.freeReq
+ if req == nil {
+ req = new(serverRequest)
+ } else {
+ server.freeReq = req.next
+ *req = serverRequest{}
+ }
+ server.reqLock.Unlock()
+ return req
+}
+
+func (server *Server) freeRequest(req *serverRequest) {
+ server.reqLock.Lock()
+ req.next = server.freeReq
+ server.freeReq = req
+ server.reqLock.Unlock()
+}
+
+func (server *Server) getResponse() *serverResponse {
+ server.respLock.Lock()
+ resp := server.freeResp
+ if resp == nil {
+ resp = new(serverResponse)
+ } else {
+ server.freeResp = resp.next
+ *resp = serverResponse{}
+ }
+ server.respLock.Unlock()
+ return resp
+}
+
+func (server *Server) freeResponse(resp *serverResponse) {
+ server.respLock.Lock()
+ resp.next = server.freeResp
+ server.freeResp = resp
+ server.respLock.Unlock()
+}
+
+// serv
+func (server *Server) Serv() {
+
+ for{
+ c,err := server.listener.Accept()
+ if err != nil {
+ log.Print("rpc:",err.Error())
+ continue
+ }
+ go server.ServeConn(c)
+ }
+
+}
+
+func (server *Server) ServeConn(conn net.Conn){
+ src := &ServerCodec{
+ cn:conn,
+ rw:bufio.NewReadWriter(bufio.NewReader(conn),bufio.NewWriter(conn)),
+ }
+ server.ServeCodec(src)
+}
+
+func (server *Server)ServeCodec(codec *ServerCodec) {
+ sending := new(sync.Mutex)
+ for {
+ service,mtype,req,argv,replyv,keepReading,err := server.readRequest(codec)
+ if err != nil {
+ if err != io.EOF {
+ log.Println(err)
+ }
+ if !keepReading {
+ break
+ }
+ // we just got the req
+ if req != nil {
+ server.sendResponse(nil,req,codec,err,sending)
+ server.freeRequest(req)
+ }
+ continue
+ }
+ go service.call(server,mtype,req,argv,replyv,codec,sending)
+ }
+ // to do some recover
+ codec.Close()
+}
+
+func (server *Server)readRequest(codec *ServerCodec) (service *service,mtype *methodType,req *serverRequest,argv reflect.Value,replyv reflect.Value,keepReading bool,err error){
+ req,keepReading,err = server.readRequestHeader(codec)
+ if err != nil {
+ return
+ }
+ service,mtype,argv,replyv,keepReading,err = server.readRequestBody(codec,req)
+ return
+}
+
+func (server *Server)readRequestBody(codec *ServerCodec,req *serverRequest) (service *service,mtype *methodType, argv reflect.Value,replyv reflect.Value,keepReading bool,err error){
+ err = codec.ReadRequestBody(req)
+ if err != nil {
+ return
+ }
+ // funcname'format -- service.method
+
+ keepReading = true
+ serviceMethod := strings.Split(req.Method,".")
+ if len(serviceMethod) != 2 {
+ err = errors.New("rpc: service/method request ill-formed: " + req.Method)
+ return
+ }
+ // look up the service
+ server.mu.Lock()
+ service = server.serviceMap[serviceMethod[0]]
+ server.mu.Unlock()
+ if service == nil {
+ err = errors.New("rpc: can't find service " + serviceMethod[0])
+ return
+ }
+ // look up the method
+ mtype = service.method[serviceMethod[1]]
+ if mtype == nil {
+ err = errors.New("rpc: can't find method " + serviceMethod[1])
+ return
+ }
+
+ argIsValue := false
+ if mtype.ArgType.Kind() == reflect.Ptr {
+ argv = reflect.New(mtype.ArgType.Elem())
+ } else {
+ argv = reflect.New(mtype.ArgType)
+ argIsValue = true
+ }
+
+ //argv now is a pointer now
+ if err = req.Argument.Unmarshal(argv.Interface()); err != nil {
+ return
+ }
+
+ if argIsValue {
+ argv = argv.Elem()
+ }
+
+ replyv = reflect.New(mtype.ReplyType.Elem())
+ return
+}
+
+func (server *Server)readRequestHeader(codec *ServerCodec) (req *serverRequest,keepReading bool,err error){
+ req = server.getRequest()
+ err = codec.ReadRequestHeader(req)
+ if err != nil {
+ req = nil
+ if err == io.EOF || err == io.ErrUnexpectedEOF {
+ return
+ }
+ err = errors.New("rpc: server cannot decode the requestheader: " + err.Error())
+ }
+ return
+}
+
+func (server *Server)sendResponse(reply interface{},req *serverRequest,codec *ServerCodec,err interface{},sending *sync.Mutex) {
+ var rerr error
+ res := server.getResponse()
+ switch err.(type) {
+ case nil:
+ res.Operation = uint8(2)
+ res.Reply = reply
+ case BackendError:
+ res.Operation = uint8(3)
+ res.Reply = err
+ case error:
+ res.Operation = uint8(3)
+ res.Reply = BackendError{Message:"InternalError",Detail:err.(error).Error()}
+ default:
+ res.Operation = uint8(3)
+ res.Reply = BackendError{Message:"InternalError",Detail:"error is unvalid"}
+ }
+ sending.Lock()
+ rerr = codec.WriteResponse(res)
+ if rerr != nil {
+ log.Println("rpc error:",rerr)
+ }
+ sending.Unlock()
+ server.freeResponse(res)
+}
+
+
+// run the service.method
+func (s *service) call(server *Server,mtype *methodType,req *serverRequest,argv,replyv reflect.Value, codec *ServerCodec, sending *sync.Mutex) {
+ defer func(){
+ // it may be panic in the method
+ if r := recover(); r != nil {
+ err := errors.New(fmt.Sprint(r))
+ server.sendResponse(nil,req,codec,err,sending)
+ server.freeRequest(req)
+ }
+ }()
+ function := mtype.method.Func
+ returnValues := function.Call([]reflect.Value{s.rcvr,argv,replyv})
+ err := returnValues[0].Interface()
+ server.sendResponse(replyv.Interface(),req,codec,err,sending)
+ server.freeRequest(req)
+}
+
+
+
+
+//////////////////////////////////////////////////////////////////////
+// some test
+

0 comments on commit 5db804e

Please sign in to comment.
Something went wrong with that request. Please try again.