Skip to content

Commit

Permalink
feat(erigon): add api ratelimit and apikey (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
zibuyu28 committed Jul 22, 2024
1 parent 429c38b commit 2e6020c
Show file tree
Hide file tree
Showing 12 changed files with 341 additions and 16 deletions.
8 changes: 8 additions & 0 deletions cmd/rpcdaemon/cli/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ func RootCommand() (*cobra.Command, *httpcfg.HttpCfg) {
rootCmd.PersistentFlags().IntVar(&cfg.ReturnDataLimit, utils.RpcReturnDataLimit.Name, utils.RpcReturnDataLimit.Value, utils.RpcReturnDataLimit.Usage)

rootCmd.PersistentFlags().StringVar(&cfg.L2RpcUrl, utils.L2RpcUrlFlag.Name, utils.L2RpcUrlFlag.Value, utils.L2RpcUrlFlag.Usage)
rootCmd.PersistentFlags().StringVar(&cfg.HttpApiKeys, utils.HTTPApiKeysFlag.Name, utils.HTTPApiKeysFlag.Value, utils.HTTPApiKeysFlag.Usage)
rootCmd.PersistentFlags().StringVar(&cfg.MethodRateLimit, utils.MethodRateLimitFlag.Name, utils.MethodRateLimitFlag.Value, utils.MethodRateLimitFlag.Usage)

if err := rootCmd.MarkPersistentFlagFilename("rpc.accessList", "json"); err != nil {
panic(err)
Expand Down Expand Up @@ -505,6 +507,9 @@ func startRegularRpcServer(ctx context.Context, cfg httpcfg.HttpCfg, rpcAPI []rp
log.Trace("TraceRequests = %t\n", cfg.TraceRequests)
srv := rpc.NewServer(cfg.RpcBatchConcurrency, cfg.TraceRequests, cfg.RpcStreamingDisable)

// For X Layer
rpc.InitRateLimit(cfg.MethodRateLimit)

allowListForRPC, err := parseAllowListForRPC(cfg.RpcAllowListFilePath)
if err != nil {
return err
Expand Down Expand Up @@ -545,6 +550,9 @@ func startRegularRpcServer(ctx context.Context, cfg httpcfg.HttpCfg, rpcAPI []rp
return err
}

// For X Layer
apiHandler = rpc.ApiAuthHandler(cfg.HttpApiKeys, apiHandler)

listener, httpAddr, err := node.StartHTTPEndpoint(httpEndpoint, cfg.HTTPTimeouts, apiHandler)
if err != nil {
return fmt.Errorf("could not start RPC api: %w", err)
Expand Down
8 changes: 5 additions & 3 deletions cmd/rpcdaemon/cli/httpcfg/http_cfg.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ type HttpCfg struct {
ReturnDataLimit int // Maximum number of bytes returned from calls (like eth_call)

// zkevm
DataStreamPort int
DataStreamHost string
L2RpcUrl string
DataStreamPort int
DataStreamHost string
L2RpcUrl string
HttpApiKeys string
MethodRateLimit string
}
13 changes: 13 additions & 0 deletions cmd/utils/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,19 @@ var (
Usage: "API's offered over the HTTP-RPC interface",
Value: "eth,erigon,engine",
}
HTTPApiKeysFlag = cli.StringFlag{
Name: "http.apikeys",
Usage: `API keys for the HTTP-RPC server and you can add rate limit to this apikey , format:
{"project":"project1","key":"apikey1","timeout":"2023-12-12"}
{"project":"project2","key":"apikey2","timeout":"2023-12-12"}
{"project":"project3","key":"apikey3","timeout":"2023-12-12","methods":["method1","method2"],"count":1,"bucket":1}`,
Value: "",
}
MethodRateLimitFlag = cli.StringFlag{
Name: "http.methodratelimit",
Usage: "Method rate limit in requests per second, format: {\"method\":[\"method1\",\"method2\"],\"count\":1,\"bucket\":1}, eg. {\"methods\":[\"eth_call\",\"eth_blockNumber\"],\"count\":10,\"bucket\":1}",
Value: "",
}
L2ChainIdFlag = cli.Uint64Flag{
Name: "zkevm.l2-chain-id",
Usage: "L2 chain ID",
Expand Down
148 changes: 148 additions & 0 deletions rpc/api_auth_xlayer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package rpc

import (
"crypto/md5"
"encoding/json"
"errors"
"fmt"
"net/http"
"path"
"strings"
"time"

"github.com/ledgerwatch/erigon/zkevm/jsonrpc/types"
"github.com/ledgerwatch/log/v3"
)

// ApiKeyItem is the api key item
type ApiKeyItem struct {
// Name defines the name of the key
Project string `json:"project"`
// Key defines the key
Key string `json:"key"`
// Timeout defines the timeout
Timeout string `json:"timeout"`
// Methods defines the methods
rateLimitConfig *RateLimitConfig
}

type apiAllow struct {
allowKeys map[string]keyItem
enable bool
}

type keyItem struct {
project string
timeout time.Time
}

var al apiAllow

// InitApiAuth initializes the api authentication
func InitApiAuth(apikeysconfig string) {
if apikeysconfig == "" {
return
}
log.Info("api auth enabled", "apikeysconfig", apikeysconfig)
keyItems := strings.Split(apikeysconfig, "\n")
var keys []ApiKeyItem

for _, item := range keyItems {
var itemins = struct {
// Name defines the name of the key
Project string `json:"project"`
Key string `json:"key"`
Timeout string `json:"timeout"`
Methods []string `json:"methods"`
Count int `json:"count"`
Bucket int `json:"bucket"`
}{}
err := json.Unmarshal([]byte(item), &itemins)
if err != nil {
log.Warn("invalid key item: %s", item)
continue
}
apiKeyItem := ApiKeyItem{Project: itemins.Project, Key: itemins.Key, Timeout: itemins.Timeout}
if len(itemins.Methods) > 0 {
rlc := RateLimitConfig{
RateLimitApis: itemins.Methods,
RateLimitCount: itemins.Count,
RateLimitBucket: itemins.Bucket,
}
apiKeyItem.rateLimitConfig = &rlc
}
keys = append(keys, apiKeyItem)
}
setApiAuth(keys)
}

// setApiAuth sets the api authentication
func setApiAuth(kis []ApiKeyItem) {
al.enable = len(kis) > 0
var tmp = make(map[string]keyItem)
var rateLimitConfig = make(map[string]*RateLimitConfig)
for _, k := range kis {
k.Key = strings.ToLower(k.Key)
parse, err := time.Parse("2006-01-02", k.Timeout)
if err != nil {
log.Warn("parse key [%+v], error parsing timeout: %v", k, err)
continue
}
if strings.ToLower(fmt.Sprintf("%x", md5.Sum([]byte(k.Project+k.Timeout)))) != k.Key {
log.Warn("project [%s], key [%s] is invalid, key = md5(Project+Timeout)", k.Project, k.Key)
continue
}
tmp[k.Key] = keyItem{project: k.Project, timeout: parse}
if k.rateLimitConfig != nil {
rateLimitConfig[k.Key] = k.rateLimitConfig
}
}
al.allowKeys = tmp
initApikeyRateLimit(rateLimitConfig)
}

func check(key string) error {
key = strings.ToLower(key)
if item, ok := al.allowKeys[key]; ok && time.Now().Before(item.timeout) {
//metrics.RequestAuthCount(al.allowKeys[key].project)
return nil
} else if ok && time.Now().After(item.timeout) {
log.Warn("project [%s], key [%s] has expired, ", item.project, key)
//metrics.RequestAuthErrorCount(metrics.RequestAuthErrorTypeKeyExpired)
return errors.New("key has expired")
}
//metrics.RequestAuthErrorCount(metrics.RequestAuthErrorTypeNoAuth)
return errors.New("no authentication")
}

func apiAuthHandlerFunc(cfg string, handlerFunc http.HandlerFunc) http.HandlerFunc {
InitApiAuth(cfg)
return func(w http.ResponseWriter, r *http.Request) {
if al.enable {
if er := check(path.Base(r.URL.Path)); er != nil {
err := handleNoAuthErr(w, er)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
}
}
handlerFunc(w, r)
}
}

func ApiAuthHandler(cfg string, next http.Handler) http.Handler {
return apiAuthHandlerFunc(cfg, next.ServeHTTP)
}

func handleNoAuthErr(w http.ResponseWriter, err error) error {
respbytes, err := types.NewResponse(types.Request{JSONRPC: "2.0", ID: 0}, nil, types.NewRPCError(types.InvalidParamsErrorCode, err.Error())).Bytes()
if err != nil {
return err
}
_, err = w.Write(respbytes)
if err != nil {
return err
}
return nil
}
10 changes: 7 additions & 3 deletions rpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ type Client struct {
services *serviceRegistry
methodAllowList AllowList

// apiKey is the API key used for authentication. For X Layer
apiKey string

idCounter uint32

// This function, if non-nil, is called when the connection is lost.
Expand Down Expand Up @@ -112,7 +115,7 @@ type clientConn struct {

func (c *Client) newClientConn(conn ServerCodec) *clientConn {
ctx := context.WithValue(context.Background(), clientContextKey{}, c)
handler := newHandler(ctx, conn, c.idgen, c.services, c.methodAllowList, 50, false /* traceRequests */)
handler := newHandler(ctx, conn, c.idgen, c.services, c.methodAllowList, 50, false /* traceRequests */, c.apiKey)
return &clientConn{conn, handler}
}

Expand Down Expand Up @@ -196,17 +199,18 @@ func newClient(initctx context.Context, connect reconnectFunc) (*Client, error)
if err != nil {
return nil, err
}
c := initClient(conn, randomIDGenerator(), new(serviceRegistry))
c := initClient(conn, randomIDGenerator(), new(serviceRegistry), "")
c.reconnectFunc = connect
return c, nil
}

func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *Client {
func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry, apikey string) *Client {
_, isHTTP := conn.(*httpConn)
c := &Client{
idgen: idgen,
isHTTP: isHTTP,
services: services,
apiKey: apikey, // For X Layer
writeConn: conn,
close: make(chan struct{}),
closing: make(chan struct{}),
Expand Down
15 changes: 13 additions & 2 deletions rpc/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ type handler struct {

allowList AllowList // a list of explicitly allowed methods, if empty -- everything is allowed
forbiddenList ForbiddenList
ApiKey string

subLock sync.Mutex
serverSubs map[ID]*Subscription
Expand Down Expand Up @@ -110,7 +111,7 @@ func HandleError(err error, stream *jsoniter.Stream) error {
return nil
}

func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, allowList AllowList, maxBatchConcurrency uint, traceRequests bool) *handler {
func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, allowList AllowList, maxBatchConcurrency uint, traceRequests bool, apikey string) *handler {
rootCtx, cancelRoot := context.WithCancel(connCtx)
forbiddenList := newForbiddenList()
h := &handler{
Expand All @@ -126,6 +127,7 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *
log: log.Root(),
allowList: allowList,
forbiddenList: forbiddenList,
ApiKey: apikey,

maxBatchConcurrency: maxBatchConcurrency,
traceRequests: traceRequests,
Expand Down Expand Up @@ -174,7 +176,11 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
wg.Done()
<-boundedConcurrency
}()

// For X Layer
if !apikeyMethodRateLimitAllow(h.ApiKey, calls[i].Method) || !methodRateLimitAllow(calls[i].Method) {
answersWithNils[i] = errorMessage(fmt.Errorf("method rate limit exceeded"))
return
}
select {
case <-cp.ctx.Done():
return
Expand Down Expand Up @@ -215,6 +221,11 @@ func (h *handler) handleMsg(msg *jsonrpcMessage, stream *jsoniter.Stream) {
return
}
h.startCallProc(func(cp *callProc) {
// For X Layer
if !apikeyMethodRateLimitAllow(h.ApiKey, msg.Method) || !methodRateLimitAllow(msg.Method) {
h.conn.writeJSON(cp.ctx, errorMessage(fmt.Errorf("rate limit exceeded")))
return
}
needWriteStream := false
if stream == nil {
stream = jsoniter.NewStream(jsoniter.ConfigDefault, nil, 4096)
Expand Down
4 changes: 4 additions & 0 deletions rpc/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"mime"
"net/http"
"net/url"
"path"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -247,6 +248,9 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx = context.WithValue(ctx, "Origin", origin)
}

// For X Layer
ctx = context.WithValue(ctx, "apikey", strings.TrimPrefix(path.Base(r.URL.Path), "/"))

w.Header().Set("content-type", contentType)
codec := newHTTPServerConn(r, w)
defer codec.close()
Expand Down
Loading

0 comments on commit 2e6020c

Please sign in to comment.