Skip to content

Commit

Permalink
Creates default client only if required on relay startup (#912)
Browse files Browse the repository at this point in the history
* creates default client only if required on relay startup
  • Loading branch information
emmanuelm41 committed Feb 9, 2022
1 parent 9627b87 commit b0f85e3
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 48 deletions.
2 changes: 1 addition & 1 deletion client/test/http/mock/httpserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func NewMockHTTPPublicServer(t *testing.T, badSecondRound bool, sch scheme.Schem
client := core.Proxy(server)
ctx, cancel := context.WithCancel(context.Background())

handler, err := dhttp.New(ctx, client, "", nil)
handler, err := dhttp.New(ctx, "", nil)
if err != nil {
t.Fatal(err)
}
Expand Down
53 changes: 27 additions & 26 deletions cmd/relay/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,34 +63,29 @@ func Relay(c *cli.Context) error {
return fmt.Errorf("--%s is deprecated on relay http, please use %s instead", lib.HashFlag.Name, lib.HashListFlag.Name)
}

client, err := lib.Create(c, c.IsSet(metricsFlag.Name))
if err != nil {
return err
}

handler, err := dhttp.New(c.Context, client, fmt.Sprintf("drand/%s (%s)", version, gitCommit), log.DefaultLogger().With("binary", "relay"))
handler, err := dhttp.New(c.Context, fmt.Sprintf("drand/%s (%s)", version, gitCommit), log.DefaultLogger().With("binary", "relay"))
if err != nil {
return fmt.Errorf("failed to create rest handler: %w", err)
}

hashesList := make([]string, 0)
hashesList = append(hashesList, common.DefaultChainHash)
hashesMap := make(map[string]bool)
if c.IsSet(lib.HashListFlag.Name) {
hashesList = c.StringSlice(lib.HashListFlag.Name)
}

for _, hash := range hashesList {
if hash == common.DefaultChainHash {
handler.RegisterNewBeaconHandler(client, common.DefaultChainHash)
continue
}

if _, err := hex.DecodeString(hash); err != nil {
return fmt.Errorf("failed to decode chain hash value: %s", err)
hashesList := c.StringSlice(lib.HashListFlag.Name)
for _, hash := range hashesList {
hashesMap[hash] = true
}
} else {
hashesMap[common.DefaultChainHash] = true
}

if err := c.Set(lib.HashFlag.Name, hash); err != nil {
return fmt.Errorf("failed to initiate chain hash handler: %s", err)
for hash := range hashesMap {
if hash != common.DefaultChainHash {
if _, err := hex.DecodeString(hash); err != nil {
return fmt.Errorf("failed to decode chain hash value: %s", err)
}
if err := c.Set(lib.HashFlag.Name, hash); err != nil {
return fmt.Errorf("failed to initiate chain hash handler: %s", err)
}
}

c, err := lib.Create(c, c.IsSet(metricsFlag.Name))
Expand Down Expand Up @@ -122,11 +117,17 @@ func Relay(c *cli.Context) error {
}

// jumpstart bootup
req, _ := http.NewRequest("GET", "/public/0", http.NoBody)
rr := httptest.NewRecorder()
handler.GetHTTPHandler().ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
log.DefaultLogger().Warnw("", "binary", "relay", "startup failed", rr.Code)
for hash := range hashesMap {
req, _ := http.NewRequest("GET", "/public/0", http.NoBody)
if hash != common.DefaultChainHash {
req, _ = http.NewRequest("GET", fmt.Sprintf("/%s/public/0", hash), http.NoBody)
}

rr := httptest.NewRecorder()
handler.GetHTTPHandler().ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
log.DefaultLogger().Warnw("", "binary", "relay", "chain-hash", hash, "startup failed", rr.Code)
}
}

fmt.Printf("Listening at %s\n", listener.Addr())
Expand Down
2 changes: 1 addition & 1 deletion core/drand_daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (dd *DrandDaemon) init() error {
var err error
dd.log.Infow("", "network", "init", "insecure", c.insecure)

handler, err := dhttp.New(ctx, &drandProxy{dd}, c.Version(), dd.log.With("server", "http"))
handler, err := dhttp.New(ctx, c.Version(), dd.log.With("server", "http"))
if err != nil {
return err
}
Expand Down
32 changes: 16 additions & 16 deletions http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ type beaconHandler struct {
}

// New creates an HTTP handler for the public Drand API
func New(ctx context.Context, c client.Client, version string, logger log.Logger) (*DrandHandler, error) {
func New(ctx context.Context, version string, logger log.Logger) (*DrandHandler, error) {
if logger == nil {
logger = log.DefaultLogger()
}
Expand Down Expand Up @@ -242,40 +242,40 @@ func (h *DrandHandler) watchWithTimeout(bh *beaconHandler, ready chan bool) {
}
}

func (h *DrandHandler) getChainInfo(ctx context.Context, chainHash []byte) *chain.Info {
func (h *DrandHandler) getChainInfo(ctx context.Context, chainHash []byte) (*chain.Info, error) {
bh, err := h.getBeaconHandler(chainHash)
if err != nil {
return nil
return nil, err
}

bh.chainInfoLk.RLock()
if bh.chainInfo != nil {
info := bh.chainInfo
bh.chainInfoLk.RUnlock()
return info
return info, nil
}
bh.chainInfoLk.RUnlock()

bh.chainInfoLk.Lock()
defer bh.chainInfoLk.Unlock()

if bh.chainInfo != nil {
return bh.chainInfo
return bh.chainInfo, nil
}

ctx, cancel := context.WithTimeout(ctx, h.timeout)
defer cancel()
info, err := bh.client.Info(ctx)
if err != nil {
h.log.Warnw("", "msg", "chain info fetch failed", "err", err)
return nil
return nil, err
}
if info == nil {
h.log.Warnw("", "msg", "chain info fetch didn't return group info")
return nil
return nil, fmt.Errorf("chain info fetch didn't return group info")
}
bh.chainInfo = info
return info
return info, nil
}

func (h *DrandHandler) getRand(ctx context.Context, chainHash []byte, info *chain.Info, round uint64) ([]byte, error) {
Expand Down Expand Up @@ -358,9 +358,9 @@ func (h *DrandHandler) PublicRand(w http.ResponseWriter, r *http.Request) {
return
}

info := h.getChainInfo(r.Context(), chainHashHex)
info, err := h.getChainInfo(r.Context(), chainHashHex)
roundExpectedTime := time.Now()
if info == nil {
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
h.log.Warnw("", "http_server", "failed to get randomness", "client", r.RemoteAddr, "req", url.PathEscape(r.URL.Path), "err", err)
return
Expand Down Expand Up @@ -427,10 +427,10 @@ func (h *DrandHandler) LatestRand(w http.ResponseWriter, r *http.Request) {
return
}

info := h.getChainInfo(r.Context(), chainHashHex)
info, err := h.getChainInfo(r.Context(), chainHashHex)
roundTime := time.Now()
nextTime := time.Now()
if info != nil {
if err == nil {
roundTime = time.Unix(chain.TimeOfRound(info.Period, info.GenesisTime, resp.Round()), 0)
next := time.Unix(chain.TimeOfRound(info.Period, info.GenesisTime, resp.Round()+1), 0)
if next.After(nextTime) {
Expand Down Expand Up @@ -462,8 +462,8 @@ func (h *DrandHandler) ChainInfo(w http.ResponseWriter, r *http.Request) {
return
}

info := h.getChainInfo(r.Context(), chainHashHex)
if info == nil {
info, err := h.getChainInfo(r.Context(), chainHashHex)
if err != nil {
h.log.Warnw("", "http_server", "failed to serve group", "client", r.RemoteAddr, "req", url.PathEscape(r.URL.Path))
http.Error(w, "group not found", http.StatusNotFound)
return
Expand Down Expand Up @@ -505,7 +505,7 @@ func (h *DrandHandler) Health(w http.ResponseWriter, r *http.Request) {
lastSeen := bh.latestRound
bh.pendingLk.RUnlock()

info := h.getChainInfo(r.Context(), chainHashHex)
info, err := h.getChainInfo(r.Context(), chainHashHex)

w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-cache")
Expand All @@ -514,7 +514,7 @@ func (h *DrandHandler) Health(w http.ResponseWriter, r *http.Request) {
resp["expected"] = 0
var b []byte

if info == nil {
if err != nil {
w.WriteHeader(http.StatusServiceUnavailable)
} else {
expected := chain.CurrentRound(time.Now().Unix(), info.Period, info.GenesisTime)
Expand Down
8 changes: 4 additions & 4 deletions http/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestHTTPRelay(t *testing.T) {

c, _ := withClient(t)

handler, err := New(ctx, c, "", nil)
handler, err := New(ctx, "", nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -148,7 +148,7 @@ func TestHTTPWaiting(t *testing.T) {
defer cancel()
c, push := withClient(t)

handler, err := New(ctx, c, "", nil)
handler, err := New(ctx, "", nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -222,7 +222,7 @@ func TestHTTPWatchFuture(t *testing.T) {
defer cancel()
c, _ := withClient(t)

handler, err := New(ctx, c, "", nil)
handler, err := New(ctx, "", nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -261,7 +261,7 @@ func TestHTTPHealth(t *testing.T) {
defer cancel()
c, push := withClient(t)

handler, err := New(ctx, c, "", nil)
handler, err := New(ctx, "", nil)
if err != nil {
t.Fatal(err)
}
Expand Down

0 comments on commit b0f85e3

Please sign in to comment.