-
Notifications
You must be signed in to change notification settings - Fork 21
/
rpc.go
113 lines (91 loc) · 2.5 KB
/
rpc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
package raft
import (
"context"
"errors"
"fmt"
"github.com/justin0u0/raft/pb"
)
// rpcResponse carries the outcome of an RPC handled by the Raft node back to
// the goroutine that dispatched it.
type rpcResponse struct {
	resp interface{} // handler's reply value; nil when the request type was not recognized
	err  error       // handler's error, if any
}

// rpc is the envelope that pairs a request with the channel its reply must be
// delivered on. handleRPCRequest answers every envelope by calling respond.
type rpc struct {
	req    interface{}         // concrete request value, e.g. *pb.AppendEntriesRequest
	respCh chan<- *rpcResponse // destination for the single reply
}

// respond wraps resp and err in a freshly allocated rpcResponse and sends it
// on respCh. Like any channel send, it blocks until the channel can accept the
// value; dispatchRPCRequest creates respCh with capacity 1 for that reason.
func (c *rpc) respond(resp interface{}, err error) {
	reply := rpcResponse{err: err, resp: resp}
	c.respCh <- &reply
}
// Sentinel errors produced by the RPC plumbing. Each errors.New call yields a
// distinct value, so callers can match them with errors.Is.
var (
	// errRPCTimeout is returned by dispatchRPCRequest when the caller's
	// context is done before a reply arrives.
	errRPCTimeout = errors.New("rpc timeout")

	// errInvalidRPCType is the reply handleRPCRequest sends for a request
	// whose concrete type it does not recognize.
	errInvalidRPCType = errors.New("invalid rpc type")

	// errResponseTypeMismatch is returned by the exported RPC methods when the
	// reply is not the response type matching the request.
	errResponseTypeMismatch = errors.New("response type mismatch")

	// errNotLeader is not referenced in this file; presumably applyCommand
	// returns it when this node is not the leader. NOTE(review): confirm.
	errNotLeader = errors.New("not leader")
)
// ApplyCommand submits a client command to this node. The request is forwarded
// through dispatchRPCRequest, and the reply must be a *pb.ApplyCommandResponse.
// Raft state is persisted before the response is returned.
//
// It returns the dispatch error unchanged (including errRPCTimeout),
// errResponseTypeMismatch for an unexpected reply type, or the persistence
// error wrapped with context.
//
// NOTE(review): saveRaftState runs on the caller's goroutine, not inside
// handleRPCRequest; confirm it synchronizes access to shared Raft state.
func (r *Raft) ApplyCommand(ctx context.Context, req *pb.ApplyCommandRequest) (*pb.ApplyCommandResponse, error) {
	reply, dispatchErr := r.dispatchRPCRequest(ctx, req)
	if dispatchErr != nil {
		return nil, dispatchErr
	}

	switch resp := reply.(type) {
	case *pb.ApplyCommandResponse:
		// Persist before the reply leaves this node.
		if saveErr := r.saveRaftState(r.persister); saveErr != nil {
			return nil, fmt.Errorf("fail to save raft state: %w", saveErr)
		}
		return resp, nil
	default:
		return nil, errResponseTypeMismatch
	}
}
// AppendEntries handles an AppendEntries request. The request is forwarded
// through dispatchRPCRequest, and the reply must be a
// *pb.AppendEntriesResponse. Raft state is persisted before the response is
// returned.
//
// It returns the dispatch error unchanged (including errRPCTimeout),
// errResponseTypeMismatch for an unexpected reply type, or the persistence
// error wrapped with context.
//
// NOTE(review): saveRaftState runs on the caller's goroutine, not inside
// handleRPCRequest; confirm it synchronizes access to shared Raft state.
func (r *Raft) AppendEntries(ctx context.Context, req *pb.AppendEntriesRequest) (*pb.AppendEntriesResponse, error) {
	reply, dispatchErr := r.dispatchRPCRequest(ctx, req)
	if dispatchErr != nil {
		return nil, dispatchErr
	}

	switch resp := reply.(type) {
	case *pb.AppendEntriesResponse:
		// Persist before the reply leaves this node.
		if saveErr := r.saveRaftState(r.persister); saveErr != nil {
			return nil, fmt.Errorf("fail to save raft state: %w", saveErr)
		}
		return resp, nil
	default:
		return nil, errResponseTypeMismatch
	}
}
// RequestVote handles a RequestVote request. The request is forwarded through
// dispatchRPCRequest, and the reply must be a *pb.RequestVoteResponse. Raft
// state is persisted before the response is returned.
//
// It returns the dispatch error unchanged (including errRPCTimeout),
// errResponseTypeMismatch for an unexpected reply type, or the persistence
// error wrapped with context.
//
// NOTE(review): saveRaftState runs on the caller's goroutine, not inside
// handleRPCRequest; confirm it synchronizes access to shared Raft state.
func (r *Raft) RequestVote(ctx context.Context, req *pb.RequestVoteRequest) (*pb.RequestVoteResponse, error) {
	reply, dispatchErr := r.dispatchRPCRequest(ctx, req)
	if dispatchErr != nil {
		return nil, dispatchErr
	}

	switch resp := reply.(type) {
	case *pb.RequestVoteResponse:
		// Persist before the reply leaves this node.
		if saveErr := r.saveRaftState(r.persister); saveErr != nil {
			return nil, fmt.Errorf("fail to save raft state: %w", saveErr)
		}
		return resp, nil
	default:
		return nil, errResponseTypeMismatch
	}
}
// dispatchRPCRequest wraps req in an rpc envelope, sends it on r.rpcCh, and
// waits for the reply delivered through the envelope's respCh.
//
// It returns errRPCTimeout if ctx is done either before the envelope is
// accepted by r.rpcCh or before the reply arrives. Otherwise it returns the
// handler's error if non-nil, else the handler's response value.
func (r *Raft) dispatchRPCRequest(ctx context.Context, req interface{}) (interface{}, error) {
	// Capacity 1 lets respond complete even after this caller has stopped
	// waiting (ctx done), so the handler is never stuck on an abandoned reply.
	respCh := make(chan *rpcResponse, 1)

	// The hand-off must also honor ctx: a bare send blocks indefinitely if
	// nothing is receiving from rpcCh, ignoring the caller's deadline.
	select {
	case r.rpcCh <- &rpc{req: req, respCh: respCh}:
	case <-ctx.Done():
		return nil, errRPCTimeout
	}

	select {
	case <-ctx.Done():
		return nil, errRPCTimeout
	case rpcResp := <-respCh:
		if rpcResp.err != nil {
			return nil, rpcResp.err
		}
		return rpcResp.resp, nil
	}
}
// handleRPCRequest routes an rpc envelope to the handler for its concrete
// request type and sends exactly one reply via respond. Unrecognized request
// types are answered with a nil response and errInvalidRPCType.
//
// NOTE(review): the call site is not in this file; presumably it is invoked
// with envelopes received from r.rpcCh (the channel dispatchRPCRequest sends on).
func (r *Raft) handleRPCRequest(call *rpc) {
	var (
		resp interface{}
		err  error
	)

	switch req := call.req.(type) {
	case *pb.ApplyCommandRequest:
		resp, err = r.applyCommand(req)
	case *pb.AppendEntriesRequest:
		resp, err = r.appendEntries(req)
	case *pb.RequestVoteRequest:
		resp, err = r.requestVote(req)
	default:
		err = errInvalidRPCType
	}

	call.respond(resp, err)
}