forked from google/zoekt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rpc.go
149 lines (131 loc) · 3.96 KB
/
rpc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// Package rpc provides a zoekt.Searcher over RPC.
package rpc
import (
"context"
"encoding/gob"
"fmt"
"net/http"
"sync"
"time"
"github.com/henrik242/zoekt"
"github.com/henrik242/zoekt/query"
"github.com/henrik242/zoekt/rpc/internal/srv"
"github.com/keegancsmith/rpc"
)
// DefaultRPCPath is the rpc path used by zoekt-webserver. Client dials this
// path, since it takes only an address.
const DefaultRPCPath = "/rpc"
// Server builds the server side of the RPC calls: it registers the gob query
// types, wraps searcher in an srv.Searcher, registers that with a fresh rpc
// server and returns the rpc server as an http.Handler.
//
// It panics if the registration fails.
func Server(searcher zoekt.Searcher) http.Handler {
	RegisterGob()

	handler := rpc.NewServer()
	err := handler.Register(&srv.Searcher{Searcher: searcher})
	if err != nil {
		// Registration is not expected to fail, so treat a failure as fatal.
		msg := "unexpected error registering rpc server: " + err.Error()
		panic(msg)
	}
	return handler
}
// Client returns a Searcher that talks to the HTTP RPC server at address
// (host:port). It is shorthand for ClientAtPath with DefaultRPCPath as the
// path.
func Client(address string) zoekt.Searcher {
	const path = DefaultRPCPath
	return ClientAtPath(address, path)
}
// ClientAtPath returns a Searcher that talks to the HTTP RPC server at address
// and path (http://host:port/path).
//
// The gob query types are registered before the client is returned. No
// connection is opened here; the returned client only records address and
// path.
func ClientAtPath(address, path string) zoekt.Searcher {
	RegisterGob()

	c := new(client)
	c.addr = address
	c.path = path
	return c
}
// client is the Searcher implementation handed out by ClientAtPath. It keeps
// the dialed rpc client around and redials it through getRPCClient.
type client struct {
	// addr is the address of the RPC server that gets dialed over tcp.
	addr string
	// path is the HTTP path of the RPC endpoint on that server.
	path string

	// mu protects cl and gen.
	mu sync.Mutex
	// cl is the most recently dialed rpc client. It is nil until the first
	// successful dial.
	cl *rpc.Client
	// gen is incremented each time a dial succeeds.
	gen int
}
// Search sends q and opts to the remote "Searcher.Search" method and returns
// the Result field of the reply together with the error from the call. The
// reply is returned as is even when the error is non-nil.
func (c *client) Search(ctx context.Context, q query.Q, opts *zoekt.SearchOptions) (*zoekt.SearchResult, error) {
	args := &srv.SearchArgs{Q: q, Opts: opts}
	reply := new(srv.SearchReply)

	err := c.call(ctx, "Searcher.Search", args, reply)
	return reply.Result, err
}
// List sends q and opts to the remote "Searcher.List" method and returns the
// List field of the reply together with the error from the call. The reply is
// returned as is even when the error is non-nil.
func (c *client) List(ctx context.Context, q query.Q, opts *zoekt.ListOptions) (*zoekt.RepoList, error) {
	args := &srv.ListArgs{Q: q, Opts: opts}
	reply := new(srv.ListReply)

	err := c.call(ctx, "Searcher.List", args, reply)
	return reply.List, err
}
// call invokes serviceMethod on the server with args and stores the response
// in reply.
//
// It makes at most two attempts. The second attempt happens only when the
// first one could not obtain a client or when the call reported
// rpc.ErrShutdown; any other outcome of the first call (including success) is
// returned directly. Between the attempts it waits 100ms, returning ctx.Err()
// instead if ctx is done first.
func (c *client) call(ctx context.Context, serviceMethod string, args interface{}, reply interface{}) error {
	const retryDelay = 100 * time.Millisecond

	// First attempt. Generation 0 means "no client seen yet": this dials only
	// if nobody has dialed successfully before, otherwise it hands back the
	// current client.
	conn, gen, dialErr := c.getRPCClient(ctx, 0)
	if dialErr == nil {
		if callErr := conn.Call(ctx, serviceMethod, args, reply); callErr != rpc.ErrShutdown {
			return callErr
		}
	}

	// Pause before retrying, unless the context ends first.
	backoff := time.After(retryDelay)
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-backoff:
	}

	// Second and final attempt. Passing the generation we saw asks for a
	// redial unless another caller has already replaced the client.
	conn, _, dialErr = c.getRPCClient(ctx, gen)
	if dialErr != nil {
		return dialErr
	}
	return conn.Call(ctx, serviceMethod, args, reply)
}
// getRPCClient gets the rpc client together with its generation.
//
// gen is the generation the caller last observed (0 if it has not seen a
// client yet). If gen differs from the current generation, the client has
// already been replaced and the current one is returned as is. If gen matches
// the current generation, we redial and increment the generation. This is
// used to prevent concurrent redialing on network failure.
//
// When ctx has a deadline, the time remaining until that deadline is passed
// as the dial timeout; otherwise a zero timeout is passed. If dialing fails, a
// nil client is returned with the unchanged generation and the dial error, and
// the stored client is left untouched.
func (c *client) getRPCClient(ctx context.Context, gen int) (*rpc.Client, int, error) {
	// coarse lock so we only dial once
	c.mu.Lock()
	defer c.mu.Unlock()

	if gen != c.gen {
		return c.cl, c.gen, nil
	}

	var timeout time.Duration
	if deadline, ok := ctx.Deadline(); ok {
		timeout = time.Until(deadline)
	}
	cl, err := rpc.DialHTTPPathTimeout("tcp", c.addr, c.path, timeout)
	if err != nil {
		return nil, c.gen, err
	}

	// Release the client we are about to replace. call only asks for a redial
	// of the current generation after that client returned rpc.ErrShutdown, so
	// it cannot serve further calls; closing it frees its connection now
	// instead of leaving it open until it is garbage collected. The error is
	// deliberately ignored: the client may already be closed.
	if c.cl != nil {
		_ = c.cl.Close()
	}

	c.cl = cl
	c.gen++
	return c.cl, c.gen, nil
}
// Close closes the currently stored rpc client, if one has been dialed. The
// result of closing it is not reported. The stored client is not cleared.
func (c *client) Close() {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.cl == nil {
		// Never dialed, nothing to close.
		return
	}
	c.cl.Close()
}
// String describes the client as "rpcSearcher(<addr>/<path>)", joining the
// address and path with a single "/" exactly as they were given.
func (c *client) String() string {
	const format = "rpcSearcher(%s/%s)"
	return fmt.Sprintf(format, c.addr, c.path)
}
// once makes sure the gob registrations in RegisterGob run a single time.
var once sync.Once

// RegisterGob registers various query types with gob. It can be called more than
// once, because calls to gob.Register are protected by a sync.Once.
func RegisterGob() {
	once.Do(func() {
		// Values are registered in the order listed.
		values := []interface{}{
			&query.And{},
			&query.BranchRepos{},
			&query.BranchesRepos{},
			&query.Branch{},
			&query.Const{},
			&query.GobCache{},
			&query.Language{},
			&query.Not{},
			&query.Or{},
			&query.Regexp{},
			&query.RepoRegexp{},
			&query.RepoSet{},
			&query.Repo{},
			&query.Substring{},
			&query.Symbol{},
			&query.Type{},
			query.RawConfig(41),
		}
		for _, v := range values {
			gob.Register(v)
		}
	})
}