diff --git a/clients/agentgrpc/encoding.go b/clients/agentgrpc/encoding.go deleted file mode 100644 index f1104631..00000000 --- a/clients/agentgrpc/encoding.go +++ /dev/null @@ -1,43 +0,0 @@ -package agentgrpc - -import ( - "encoding/binary" - "fmt" - "reflect" - "unsafe" - - "google.golang.org/grpc" - "google.golang.org/grpc/encoding" - "google.golang.org/grpc/encoding/proto" -) - -var ( - defaultCodec = encoding.GetCodec(proto.Name) - destPrepMsgType = reflect.TypeOf(&grpc.PreparedMsg{}) -) - -type preparedMsg struct { - encodedData []byte - hdr []byte - payload []byte -} - -// EncodeMessage encodes request as a PreparedMsg so the client stream can use it -// directly instead of allocating a new encoded message. -// -// See https://github.com/grpc/grpc-go/blob/1ffd63de37de4571028efedb6422e29d08716d0c/stream.go#L1623 -func EncodeMessage(msg interface{}) (*grpc.PreparedMsg, error) { - msgB, err := defaultCodec.Marshal(msg) - if err != nil { - return nil, fmt.Errorf("agentgrpc: failed to encode message: %v", err) - } - hdr := make([]byte, 5) - // write length of payload into header buffer - binary.BigEndian.PutUint32(hdr[1:], uint32(len(msgB))) - // hacky conversion to avoid compiler error - return (*grpc.PreparedMsg)((unsafe.Pointer)(&preparedMsg{ - encodedData: msgB, - payload: msgB, - hdr: hdr, - })), nil -} diff --git a/clients/agentgrpc/encoding_bench_test.go b/clients/agentgrpc/encoding_bench_test.go deleted file mode 100644 index 512b351d..00000000 --- a/clients/agentgrpc/encoding_bench_test.go +++ /dev/null @@ -1,121 +0,0 @@ -package agentgrpc_test - -import ( - "context" - "encoding/json" - "fmt" - "io/ioutil" - "log" - "testing" - "time" - - "github.com/forta-network/forta-core-go/protocol" - "github.com/forta-network/forta-node/clients/agentgrpc" - "github.com/forta-network/forta-node/config" - "google.golang.org/grpc" -) - -var ( - benchBlockMsg = &protocol.EvaluateBlockRequest{} - benchTxMsg = &protocol.EvaluateTxRequest{} -) - -func init() { - b, err := ioutil.ReadFile("./testdata/bench_block.json") - if err != nil { - panic(err) - } - if err := json.Unmarshal(b, &benchBlockMsg.Event); err != nil { - panic(err) - } - b, err = ioutil.ReadFile("./testdata/bench_tx.json") - if err != nil { - panic(err) - } - if err := json.Unmarshal(b, &benchTxMsg.Event); err != nil { - panic(err) - } -} - -const benchAgentReqCount = 25 - -func getBenchClient() agentgrpc.Client { - agentClient := agentgrpc.NewClient() - for { - conn, err := grpc.Dial(fmt.Sprintf("localhost:%s", config.AgentGrpcPort), grpc.WithInsecure()) - if err == nil { - agentClient.WithConn(conn) - var success bool - _, err1 := agentClient.EvaluateBlock(context.Background(), benchBlockMsg) - _, err2 := agentClient.EvaluateTx(context.Background(), benchTxMsg) - success = (err1 == nil) && (err2 == nil) - if success { - break - } - } - time.Sleep(time.Second * 2) - log.Println("retrying to connect to grpc server") - } - - return agentClient -} - -func BenchmarkEvaluateBlock(b *testing.B) { - agentClient := getBenchClient() - for i := 0; i < b.N; i++ { - for j := 0; j < benchAgentReqCount; j++ { - out, err := agentClient.EvaluateBlock(context.Background(), benchBlockMsg) - if err != nil { - panic(err) - } - _ = out - } - } -} - -func BenchmarkEvaluateBlockWithPreparedMessage(b *testing.B) { - agentClient := getBenchClient() - for i := 0; i < b.N; i++ { - preparedMsg, err := agentgrpc.EncodeMessage(benchBlockMsg) - if err != nil { - panic(err) - } - for j := 0; j < benchAgentReqCount; j++ { - var resp protocol.EvaluateBlockResponse - err := agentClient.Invoke(context.Background(), agentgrpc.MethodEvaluateBlock, preparedMsg, &resp) - if err != nil { - panic(err) - } - } - } -} - -func BenchmarkEvaluateTx(b *testing.B) { - agentClient := getBenchClient() - for i := 0; i < b.N; i++ { - for j := 0; j < benchAgentReqCount; j++ { - out, err := agentClient.EvaluateTx(context.Background(), benchTxMsg) - if err != nil { - panic(err) - } - _ = out - } - } -} - -func BenchmarkEvaluateTxWithPreparedMessage(b *testing.B) { - agentClient := getBenchClient() - for i := 0; i < b.N; i++ { - preparedMsg, err := agentgrpc.EncodeMessage(benchTxMsg) - if err != nil { - panic(err) - } - for j := 0; j < benchAgentReqCount; j++ { - var resp protocol.EvaluateTxResponse - err := agentClient.Invoke(context.Background(), agentgrpc.MethodEvaluateTx, preparedMsg, &resp) - if err != nil { - panic(err) - } - } - } -} diff --git a/clients/agentgrpc/encoding_test.go b/clients/agentgrpc/encoding_test.go deleted file mode 100644 index 37168d24..00000000 --- a/clients/agentgrpc/encoding_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package agentgrpc_test - -import ( - "context" - "fmt" - "log" - "net" - "testing" - - "github.com/forta-network/forta-core-go/protocol" - "github.com/forta-network/forta-node/clients/agentgrpc" - "github.com/forta-network/forta-node/config" - "github.com/stretchr/testify/require" - "google.golang.org/grpc" -) - -var txMsg = &protocol.EvaluateTxRequest{ - RequestId: "123", - Event: &protocol.TransactionEvent{ - Type: protocol.TransactionEvent_BLOCK, - Transaction: &protocol.TransactionEvent_EthTransaction{ - Hash: "0xa3f0ad74e5423aebfd80d3ef4346578335a9a72aeaee59ff6cb3582b35133d50", - }, - }, -} - -type agentServer struct { - r *require.Assertions - doneCh chan struct{} - disableAssertion bool - protocol.UnimplementedAgentServer -} - -func (as *agentServer) Initialize(context.Context, *protocol.InitializeRequest) (*protocol.InitializeResponse, error) { - return &protocol.InitializeResponse{ - Status: protocol.ResponseStatus_SUCCESS, - }, nil -} - -func (as *agentServer) EvaluateTx(ctx context.Context, txRequest *protocol.EvaluateTxRequest) (*protocol.EvaluateTxResponse, error) { - if !as.disableAssertion { - as.r.Equal(txMsg.RequestId, txRequest.RequestId) - as.r.Equal(txMsg.Event.Transaction.Hash, txRequest.Event.Transaction.Hash) - close(as.doneCh) - } - return &protocol.EvaluateTxResponse{ - Status: protocol.ResponseStatus_SUCCESS, - }, nil -} - -func (as *agentServer) EvaluateBlock(context.Context, *protocol.EvaluateBlockRequest) (*protocol.EvaluateBlockResponse, error) { - return &protocol.EvaluateBlockResponse{ - Status: protocol.ResponseStatus_SUCCESS, - }, nil -} - -func TestEncodeMessage(t *testing.T) { - r := require.New(t) - - preparedMsg, err := agentgrpc.EncodeMessage(txMsg) - r.NoError(err) - log.Printf("%+v", preparedMsg) - - lis, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%s", config.AgentGrpcPort)) - r.NoError(err) - defer lis.Close() - - server := grpc.NewServer() - as := &agentServer{r: r, doneCh: make(chan struct{})} - protocol.RegisterAgentServer(server, as) - go server.Serve(lis) - - agentClient := agentgrpc.NewClient() - conn, err := grpc.Dial(fmt.Sprintf("localhost:%s", config.AgentGrpcPort), grpc.WithInsecure()) - r.NoError(err) - agentClient.WithConn(conn) - - var resp protocol.EvaluateTxResponse - r.NoError(agentClient.Invoke(context.Background(), agentgrpc.MethodEvaluateTx, preparedMsg, &resp)) - <-as.doneCh -} diff --git a/go.mod b/go.mod index dbb566de..01ceca5c 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/docker/go-connections v0.4.0 github.com/ethereum/go-ethereum v1.10.16 github.com/fatih/color v1.13.0 + github.com/forta-network/forta-core-go v0.0.0-20230605171938-4067381cbaea github.com/go-playground/validator/v10 v10.9.0 github.com/go-redis/redis v6.15.9+incompatible github.com/goccy/go-json v0.9.4 @@ -18,10 +19,16 @@ require ( github.com/gorilla/mux v1.8.0 github.com/ipfs/go-cid v0.3.2 github.com/ipfs/go-ipfs-api v0.3.0 + github.com/libp2p/go-libp2p v0.23.2 github.com/nats-io/nats-server/v2 v2.3.2 // indirect github.com/nats-io/nats.go v1.11.1-0.20210623165838-4b75fc59ae30 github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.0.2 // indirect + github.com/patrickmn/go-cache v2.1.0+incompatible + github.com/rs/cors v1.7.0 + github.com/shopspring/decimal v1.3.1 + github.com/sirupsen/logrus v1.8.1 + github.com/spf13/cobra v0.0.5 github.com/spf13/viper v1.3.2 github.com/stretchr/testify v1.8.0 github.com/tylertreat/BoomFilters v0.0.0-20210315201527-1a82519a3e43 @@ -31,16 +38,6 @@ require ( gopkg.in/yaml.v3 v3.0.1 ) -require ( - github.com/forta-network/forta-core-go v0.0.0-20230601180321-91aaa41a0fb7 - github.com/libp2p/go-libp2p v0.23.2 - github.com/patrickmn/go-cache v2.1.0+incompatible - github.com/rs/cors v1.7.0 - github.com/shopspring/decimal v1.3.1 - github.com/sirupsen/logrus v1.8.1 - github.com/spf13/cobra v0.0.5 -) - require ( bazil.org/fuse v0.0.0-20200117225306-7b5117fecadc // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect diff --git a/go.sum b/go.sum index 3ea99280..cd70f92f 100644 --- a/go.sum +++ b/go.sum @@ -325,8 +325,8 @@ github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6/go.mod h1:1i71OnUq3iUe github.com/flynn/noise v1.0.0 h1:DlTHqmzmvcEiKj+4RYo/imoswx/4r6iBlCMfVtrMXpQ= github.com/flynn/noise v1.0.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= -github.com/forta-network/forta-core-go v0.0.0-20230601180321-91aaa41a0fb7 h1:C3dGzZMW9G8kf+acq+zm4vMaz5tCG/TgFp7jC6D/ur8= -github.com/forta-network/forta-core-go v0.0.0-20230601180321-91aaa41a0fb7/go.mod h1:gffFqv24ErxEILzjvhXCaVHa2dljzdILlaJyUlkDnfw= +github.com/forta-network/forta-core-go v0.0.0-20230605171938-4067381cbaea h1:crNFjYYoqw1l5HuJPuGPOqizhxxZtjvDqAQzlBYiYh0= +github.com/forta-network/forta-core-go v0.0.0-20230605171938-4067381cbaea/go.mod h1:gffFqv24ErxEILzjvhXCaVHa2dljzdILlaJyUlkDnfw= github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= github.com/francoispqt/gojay v1.2.13 h1:d2m3sFjloqoIUQU3TsHBgj6qg/BVGlTBeHDUmyJnXKk= github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= diff --git a/services/components/botio/bot_client.go b/services/components/botio/bot_client.go index 48a0a288..eee95193 100644 --- a/services/components/botio/bot_client.go +++ b/services/components/botio/bot_client.go @@ -30,8 +30,6 @@ type BotClient interface { Config() config.AgentConfig SetConfig(config.AgentConfig) - Started() <-chan struct{} - IsStarted() bool Initialized() <-chan struct{} IsInitialized() bool Closed() <-chan struct{} @@ -58,15 +56,16 @@ type BotClient interface { // Constants const ( - DefaultBufferSize = 2000 - AgentTimeout = 30 * time.Second - MaxFindings = 50 - DefaultAgentInitializeTimeout = 5 * time.Minute + DefaultBufferSize = 2000 + RequestTimeout = 30 * time.Second + MaxFindings = 50 + DefaultInitializeTimeout = 5 * time.Minute ) // botClient receives blocks and transactions, and produces results. type botClient struct { ctx context.Context + ctxCancel func() configUnsafe config.AgentConfig alertConfigUnsafe protocol.AlertConfig @@ -83,12 +82,10 @@ type botClient struct { dialer agentgrpc.BotDialer clientUnsafe agentgrpc.Client - started chan struct{} - startedOnce sync.Once initialized chan struct{} initializedOnce sync.Once - closed chan struct{} - closeOnce sync.Once + + closeOnce sync.Once mu sync.RWMutex } @@ -129,8 +126,10 @@ func NewBotClient( msgClient clients.MessageClient, lifecycleMetrics metrics.Lifecycle, botDialer agentgrpc.BotDialer, resultChannels botreq.SendOnlyChannels, ) *botClient { + botCtx, botCtxCancel := context.WithCancel(ctx) return &botClient{ - ctx: ctx, + ctx: botCtx, + ctxCancel: botCtxCancel, configUnsafe: botCfg, txRequests: make(chan *botreq.TxRequest, DefaultBufferSize), blockRequests: make(chan *botreq.BlockRequest, DefaultBufferSize), @@ -140,9 +139,7 @@ func NewBotClient( msgClient: msgClient, lifecycleMetrics: lifecycleMetrics, dialer: botDialer, - started: make(chan struct{}), initialized: make(chan struct{}), - closed: make(chan struct{}), } } @@ -159,7 +156,6 @@ func (bot *botClient) LogStatus() { "bot": bot.Config().ID, "blockBuffer": len(bot.blockRequests), "txBuffer": len(bot.txRequests), - "started": bot.IsStarted(), "initialized": bot.IsInitialized(), "closed": bot.IsClosed(), }).Debug("bot status") @@ -240,7 +236,7 @@ func (bot *botClient) CombinationRequestCh() chan<- *botreq.CombinationRequest { // Close implements io.Closer. func (bot *botClient) Close() error { bot.closeOnce.Do(func() { - close(bot.closed) // never close this anywhere else + bot.ctxCancel() client := bot.grpcClient() if client != nil { client.Close() @@ -258,12 +254,12 @@ func (bot *botClient) Close() error { // Closed returns the closed channel. func (bot *botClient) Closed() <-chan struct{} { - return bot.closed + return bot.ctx.Done() } // IsClosed tells if the bot is closed. func (bot *botClient) IsClosed() bool { - return isChanClosed(bot.closed) + return isChanClosed(bot.ctx.Done()) } // setInitialized sets the bot as initialized. @@ -285,26 +281,7 @@ func (bot *botClient) IsInitialized() bool { return isChanClosed(bot.initialized) } -// setStarted sets the bot as started. -func (bot *botClient) setStarted() { - bot.startedOnce.Do( - func() { - close(bot.started) // never close this anywhere else - }, - ) -} - -// Started returns the started channel. -func (bot *botClient) Started() <-chan struct{} { - return bot.started -} - -// IsStarted tells if the bot has been started. -func (bot *botClient) IsStarted() bool { - return isChanClosed(bot.started) -} - -func isChanClosed(ch chan struct{}) bool { +func isChanClosed(ch <-chan struct{}) bool { select { case _, ok := <-ch: return !ok @@ -313,14 +290,6 @@ func isChanClosed(ch chan struct{}) bool { } } -// StartProcessing launches the goroutines to concurrently process incoming requests -// from request channels. -func (bot *botClient) StartProcessing() { - go bot.processTransactions() - go bot.processBlocks() - go bot.processCombinationAlerts() -} - // Initialize initializes the bot. func (bot *botClient) Initialize() { bot.initialize() @@ -335,7 +304,6 @@ func (bot *botClient) initialize() { // publish start metric to track bot starts/restarts. bot.lifecycleMetrics.ClientDial(botConfig) - bot.setStarted() botClient, err := bot.dialer.DialBot(botConfig) if err != nil { @@ -346,7 +314,7 @@ func (bot *botClient) initialize() { bot.lifecycleMetrics.StatusAttached(botConfig) logger.Info("attached to bot") - ctx, cancel := context.WithTimeout(bot.ctx, DefaultAgentInitializeTimeout) + ctx, cancel := context.WithTimeout(bot.ctx, DefaultInitializeTimeout) defer cancel() // invoke initialize method of the bot @@ -415,25 +383,77 @@ func validateInitializeResponse(response *protocol.InitializeResponse) error { return nil } +// StartProcessing launches the goroutines to concurrently process incoming requests +// from request channels. +func (bot *botClient) StartProcessing() { + go bot.processTransactions() + go bot.processBlocks() + go bot.processCombinationAlerts() +} + +func processRequests[R any]( + ctx context.Context, reqCh <-chan *R, closedCh <-chan struct{}, logger *log.Entry, + processFunc func(context.Context, *log.Entry, *R) bool, +) { + for { + select { + case <-ctx.Done(): + logger.WithError(ctx.Err()).Info("bot context is done") + return + + case request := <-reqCh: + ctx, cancel := context.WithTimeout(ctx, RequestTimeout) + exit := processFunc(ctx, logger, request) + cancel() + if exit { + return + } + } + } +} + func (bot *botClient) processTransactions() { lg := log.WithFields( log.Fields{ "bot": bot.Config().ID, - "component": "pool-bot", + "component": "bot-client", "evaluate": "transaction", }, ) <-bot.Initialized() - for request := range bot.txRequests { - if exit := bot.processTransaction(lg, request); exit { - return - } - } + processRequests(bot.ctx, bot.txRequests, bot.Closed(), lg, bot.processTransaction) +} +func (bot *botClient) processBlocks() { + lg := log.WithFields( + log.Fields{ + "bot": bot.Config().ID, + "component": "bot-client", + "evaluate": "block", + }, + ) + + <-bot.Initialized() + + processRequests(bot.ctx, bot.blockRequests, bot.Closed(), lg, bot.processBlock) } -func (bot *botClient) processTransaction(lg *log.Entry, request *botreq.TxRequest) (exit bool) { +func (bot *botClient) processCombinationAlerts() { + lg := log.WithFields( + log.Fields{ + "bot": bot.Config().ID, + "component": "bot-client", + "evaluate": "combination", + }, + ) + + <-bot.Initialized() + + processRequests(bot.ctx, bot.combinationRequests, bot.Closed(), lg, bot.processCombinationAlert) +} + +func (bot *botClient) processTransaction(ctx context.Context, lg *log.Entry, request *botreq.TxRequest) (exit bool) { botConfig := bot.Config() botClient := bot.grpcClient() @@ -443,14 +463,13 @@ func (bot *botClient) processTransaction(lg *log.Entry, request *botreq.TxReques startTime := time.Now() - ctx, cancel := context.WithTimeout(bot.ctx, AgentTimeout) lg.WithField("duration", time.Since(startTime)).Debugf("sending request") resp := new(protocol.EvaluateTxResponse) requestTime := time.Now().UTC() - err := botClient.Invoke(ctx, agentgrpc.MethodEvaluateTx, request.Encoded, resp) + err := botClient.Invoke(ctx, agentgrpc.MethodEvaluateTx, request.Original, resp) responseTime := time.Now().UTC() - cancel() + if err == nil { // truncate findings if len(resp.Findings) > MaxFindings { @@ -501,25 +520,7 @@ func (bot *botClient) processTransaction(lg *log.Entry, request *botreq.TxReques return false } -func (bot *botClient) processBlocks() { - lg := log.WithFields( - log.Fields{ - "bot": bot.Config().ID, - "component": "bot", - "evaluate": "block", - }, - ) - - <-bot.Initialized() - - for request := range bot.blockRequests { - if exit := bot.processBlock(lg, request); exit { - return - } - } -} - -func (bot *botClient) processBlock(lg *log.Entry, request *botreq.BlockRequest) (exit bool) { +func (bot *botClient) processBlock(ctx context.Context, lg *log.Entry, request *botreq.BlockRequest) (exit bool) { botConfig := bot.Config() botClient := bot.grpcClient() @@ -529,13 +530,12 @@ func (bot *botClient) processBlock(lg *log.Entry, request *botreq.BlockRequest) startTime := time.Now() - ctx, cancel := context.WithTimeout(bot.ctx, AgentTimeout) lg.WithField("duration", time.Since(startTime)).Debugf("sending request") resp := new(protocol.EvaluateBlockResponse) requestTime := time.Now().UTC() - err := botClient.Invoke(ctx, agentgrpc.MethodEvaluateBlock, request.Encoded, resp) + err := botClient.Invoke(ctx, agentgrpc.MethodEvaluateBlock, request.Original, resp) responseTime := time.Now().UTC() - cancel() + if err == nil { // truncate findings if len(resp.Findings) > MaxFindings { @@ -588,25 +588,7 @@ func (bot *botClient) processBlock(lg *log.Entry, request *botreq.BlockRequest) return false } -func (bot *botClient) processCombinationAlerts() { - lg := log.WithFields( - log.Fields{ - "bot": bot.Config().ID, - "component": "bot", - "evaluate": "combination", - }, - ) - - <-bot.Initialized() - - for request := range bot.combinationRequests { - if exit := bot.processCombinationAlert(lg, request); exit { - return - } - } -} - -func (bot *botClient) processCombinationAlert(lg *log.Entry, request *botreq.CombinationRequest) bool { +func (bot *botClient) processCombinationAlert(ctx context.Context, lg *log.Entry, request *botreq.CombinationRequest) bool { botConfig := bot.Config() botClient := bot.grpcClient() @@ -616,13 +598,11 @@ func (bot *botClient) processCombinationAlert(lg *log.Entry, request *botreq.Com startTime := time.Now() - ctx, cancel := context.WithTimeout(bot.ctx, AgentTimeout) lg.WithField("duration", time.Since(startTime)).Debugf("sending request") resp := new(protocol.EvaluateAlertResponse) requestTime := time.Now().UTC() - err := botClient.Invoke(ctx, agentgrpc.MethodEvaluateAlert, request.Encoded, resp) + err := botClient.Invoke(ctx, agentgrpc.MethodEvaluateAlert, request.Original, resp) responseTime := time.Now().UTC() - cancel() if err != nil { if status.Code(err) != codes.Unimplemented { diff --git a/services/components/botio/bot_client_test.go b/services/components/botio/bot_client_test.go index 144b4c52..bc8f53d5 100644 --- a/services/components/botio/bot_client_test.go +++ b/services/components/botio/bot_client_test.go @@ -15,7 +15,6 @@ import ( "github.com/forta-network/forta-node/config" "github.com/forta-network/forta-node/services/components/botio/botreq" mock_metrics "github.com/forta-network/forta-node/services/components/metrics/mocks" - "google.golang.org/grpc" "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" @@ -104,13 +103,9 @@ func (s *BotClientSuite) TestStartProcessStop() { }, }, } - encodedTxReq, err := agentgrpc.EncodeMessage(txReq) - s.r.NoError(err) txResp := &protocol.EvaluateTxResponse{Metadata: map[string]string{"imageHash": ""}} blockReq := &protocol.EvaluateBlockRequest{Event: &protocol.BlockEvent{BlockNumber: "123123"}} - encodedBlockReq, err := agentgrpc.EncodeMessage(blockReq) - s.r.NoError(err) blockResp := &protocol.EvaluateBlockResponse{Metadata: map[string]string{"imageHash": ""}} combinerReq := &protocol.EvaluateAlertRequest{ @@ -123,17 +118,14 @@ func (s *BotClientSuite) TestStartProcessStop() { }, }, } - encodedCombinerReq, err := agentgrpc.EncodeMessage(combinerReq) - s.r.NoError(err) combinerResp := &protocol.EvaluateAlertResponse{Metadata: map[string]string{"imageHash": ""}} // test tx handling s.botGrpc.EXPECT().Invoke( gomock.Any(), agentgrpc.MethodEvaluateTx, - gomock.AssignableToTypeOf(&grpc.PreparedMsg{}), gomock.AssignableToTypeOf(&protocol.EvaluateTxResponse{}), + gomock.AssignableToTypeOf(&protocol.EvaluateTxRequest{}), gomock.AssignableToTypeOf(&protocol.EvaluateTxResponse{}), ).Return(nil) s.botClient.TxRequestCh() <- &botreq.TxRequest{ - Encoded: encodedTxReq, Original: txReq, } txResult := <-s.resultChannels.Tx @@ -143,10 +135,9 @@ func (s *BotClientSuite) TestStartProcessStop() { // test block handling s.botGrpc.EXPECT().Invoke( gomock.Any(), agentgrpc.MethodEvaluateBlock, - gomock.AssignableToTypeOf(&grpc.PreparedMsg{}), gomock.AssignableToTypeOf(&protocol.EvaluateBlockResponse{}), + gomock.AssignableToTypeOf(&protocol.EvaluateBlockRequest{}), gomock.AssignableToTypeOf(&protocol.EvaluateBlockResponse{}), ).Return(nil) s.botClient.BlockRequestCh() <- &botreq.BlockRequest{ - Encoded: encodedBlockReq, Original: blockReq, } blockResult := <-s.resultChannels.Block @@ -156,10 +147,9 @@ func (s *BotClientSuite) TestStartProcessStop() { // test combine alert handling s.botGrpc.EXPECT().Invoke( gomock.Any(), agentgrpc.MethodEvaluateAlert, - gomock.AssignableToTypeOf(&grpc.PreparedMsg{}), gomock.AssignableToTypeOf(&protocol.EvaluateAlertResponse{}), + gomock.AssignableToTypeOf(&protocol.EvaluateAlertRequest{}), gomock.AssignableToTypeOf(&protocol.EvaluateAlertResponse{}), ).Return(nil) s.botClient.CombinationRequestCh() <- &botreq.CombinationRequest{ - Encoded: encodedCombinerReq, Original: combinerReq, } alertResult := <-s.resultChannels.CombinationAlert diff --git a/services/components/botio/botreq/request.go b/services/components/botio/botreq/request.go index 23b855d7..050d69c6 100644 --- a/services/components/botio/botreq/request.go +++ b/services/components/botio/botreq/request.go @@ -2,23 +2,19 @@ package botreq import ( "github.com/forta-network/forta-core-go/protocol" - "google.golang.org/grpc" ) -// TxRequest contains the original request data and the encoded message. +// TxRequest contains the request data. type TxRequest struct { Original *protocol.EvaluateTxRequest - Encoded *grpc.PreparedMsg } -// BlockRequest contains the original request data and the encoded message. +// BlockRequest contains the request data. type BlockRequest struct { Original *protocol.EvaluateBlockRequest - Encoded *grpc.PreparedMsg } -// CombinationRequest contains the original request data and the encoded message. +// CombinationRequest contains the request data. type CombinationRequest struct { Original *protocol.EvaluateAlertRequest - Encoded *grpc.PreparedMsg } diff --git a/services/components/botio/mocks/mock_bot_client.go b/services/components/botio/mocks/mock_bot_client.go index 69528190..b9f036eb 100644 --- a/services/components/botio/mocks/mock_bot_client.go +++ b/services/components/botio/mocks/mock_bot_client.go @@ -175,20 +175,6 @@ func (mr *MockBotClientMockRecorder) IsInitialized() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsInitialized", reflect.TypeOf((*MockBotClient)(nil).IsInitialized)) } -// IsStarted mocks base method. -func (m *MockBotClient) IsStarted() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsStarted") - ret0, _ := ret[0].(bool) - return ret0 -} - -// IsStarted indicates an expected call of IsStarted. -func (mr *MockBotClientMockRecorder) IsStarted() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsStarted", reflect.TypeOf((*MockBotClient)(nil).IsStarted)) -} - // LogStatus mocks base method. func (m *MockBotClient) LogStatus() { m.ctrl.T.Helper() @@ -253,20 +239,6 @@ func (mr *MockBotClientMockRecorder) StartProcessing() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartProcessing", reflect.TypeOf((*MockBotClient)(nil).StartProcessing)) } -// Started mocks base method. -func (m *MockBotClient) Started() <-chan struct{} { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Started") - ret0, _ := ret[0].(<-chan struct{}) - return ret0 -} - -// Started indicates an expected call of Started. -func (mr *MockBotClientMockRecorder) Started() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Started", reflect.TypeOf((*MockBotClient)(nil).Started)) -} - // TxBufferIsFull mocks base method. func (m *MockBotClient) TxBufferIsFull() bool { m.ctrl.T.Helper() diff --git a/services/components/botio/sender.go b/services/components/botio/sender.go index a408b19d..a101debe 100644 --- a/services/components/botio/sender.go +++ b/services/components/botio/sender.go @@ -9,7 +9,6 @@ import ( "github.com/forta-network/forta-core-go/clients/health" "github.com/forta-network/forta-core-go/protocol" "github.com/forta-network/forta-node/clients" - "github.com/forta-network/forta-node/clients/agentgrpc" "github.com/forta-network/forta-node/clients/messaging" "github.com/forta-network/forta-node/services/components/botio/botreq" "github.com/forta-network/forta-node/services/components/metrics" @@ -94,11 +93,6 @@ func (rs *requestSender) SendEvaluateTxRequest(req *protocol.EvaluateTxRequest) bots := rs.botPool.GetCurrentBotClients() - encoded, err := agentgrpc.EncodeMessage(req) - if err != nil { - lg.WithError(err).Error("failed to encode message") - return - } var metricsList []*protocol.AgentMetric for _, bot := range bots { if !bot.ShouldProcessBlock(req.Event.Block.BlockNumber) { @@ -116,7 +110,6 @@ func (rs *requestSender) SendEvaluateTxRequest(req *protocol.EvaluateTxRequest) lg.WithField("bot", bot.Config().ID).Debug("bot is closed - skipping") case bot.TxRequestCh() <- &botreq.TxRequest{ Original: req, - Encoded: encoded, }: default: // do not try to send if the buffer is full lg.WithField("bot", bot.Config().ID).Debug("agent tx request buffer is full - skipping") @@ -148,12 +141,6 @@ func (rs *requestSender) SendEvaluateBlockRequest(req *protocol.EvaluateBlockReq bots := rs.botPool.GetCurrentBotClients() - encoded, err := agentgrpc.EncodeMessage(req) - if err != nil { - lg.WithError(err).Error("failed to encode message") - return - } - var metricsList []*protocol.AgentMetric for _, bot := range bots { if !bot.ShouldProcessBlock(req.Event.BlockNumber) { @@ -171,7 +158,6 @@ func (rs *requestSender) SendEvaluateBlockRequest(req *protocol.EvaluateBlockReq lg.WithField("bot", bot.Config().ID).Debug("bot is closed - skipping") case bot.BlockRequestCh() <- &botreq.BlockRequest{ Original: req, - Encoded: encoded, }: default: // do not try to send if the buffer is full lg.WithField("bot", bot.Config().ID).Warn("agent block request buffer is full - skipping") @@ -218,12 +204,6 @@ func (rs *requestSender) SendEvaluateAlertRequest(req *protocol.EvaluateAlertReq bots := rs.botPool.GetCurrentBotClients() - encoded, err := agentgrpc.EncodeMessage(req) - if err != nil { - lg.WithError(err).Error("failed to encode message") - return - } - var metricsList []*protocol.AgentMetric var target BotClient @@ -265,7 +245,6 @@ func (rs *requestSender) SendEvaluateAlertRequest(req *protocol.EvaluateAlertReq lg.WithField("bot", target.Config().ID).Debug("bot is closed - skipping") case target.CombinationRequestCh() <- &botreq.CombinationRequest{ Original: req, - Encoded: encoded, }: default: // do not try to send if the buffer is full lg.WithField("bot", target.Config().ID).Warn("agent alert request buffer is full - skipping") diff --git a/services/components/containers/bot_client.go b/services/components/containers/bot_client.go index 63e6f9dc..a64642c4 100644 --- a/services/components/containers/bot_client.go +++ b/services/components/containers/bot_client.go @@ -26,7 +26,7 @@ const ( type BotClient interface { EnsureBotImages(ctx context.Context, botConfigs []config.AgentConfig) []error LaunchBot(ctx context.Context, botConfig config.AgentConfig) error - TearDownBot(ctx context.Context, containerName string) error + TearDownBot(ctx context.Context, containerName string, removeImage bool) error StopBot(ctx context.Context, botConfig config.AgentConfig) error LoadBotContainers(ctx context.Context) ([]types.Container, error) StartWaitBotContainer(ctx context.Context, containerID string) error @@ -134,7 +134,7 @@ func getServiceContainerNames() []string { } // TearDownBot tears down a bot by shutting down the docker container and removing it. -func (bc *botClient) TearDownBot(ctx context.Context, containerName string) error { +func (bc *botClient) TearDownBot(ctx context.Context, containerName string, removeImage bool) error { container, err := bc.client.GetContainerByName(ctx, containerName) if err != nil { return fmt.Errorf("failed to get the bot container to tear down: %v", err) @@ -143,6 +143,7 @@ func (bc *botClient) TearDownBot(ctx context.Context, containerName string) erro if err != nil { return fmt.Errorf("failed to get service container ids during bot cleanup: %v", err) } + defer log.WithField("botContainer", containerName).Info("done tearing down the bot and the associated docker resources") // not returning any errors in `if`s below so we keep on by removing whatever is left for _, serviceContainerID := range serviceContainerIDs { if err := bc.client.DetachNetwork(ctx, serviceContainerID, containerName); err != nil { @@ -163,12 +164,14 @@ func (bc *botClient) TearDownBot(ctx context.Context, containerName string) erro "network": containerName, }).WithError(err).Warn("failed to destroy the bot network") } + if !removeImage { + return nil + } if err := bc.client.RemoveImage(ctx, container.Image); err != nil { log.WithFields(log.Fields{ "image": container.Image, }).WithError(err).Warn("failed to remove image of the destroyed bot container") } - log.WithField("botContainer", containerName).Info("done tearing down the bot and the associated docker resources") return nil } diff --git a/services/components/containers/bot_client_test.go b/services/components/containers/bot_client_test.go index 8f15ec01..47081831 100644 --- a/services/components/containers/bot_client_test.go +++ b/services/components/containers/bot_client_test.go @@ -145,7 +145,7 @@ func (s *BotClientTestSuite) TestTearDownBot() { s.client.EXPECT().RemoveNetworkByName(gomock.Any(), botConfig.ContainerName()) s.client.EXPECT().RemoveImage(gomock.Any(), testImageRef) - s.r.NoError(s.botClient.TearDownBot(context.Background(), botConfig.ContainerName())) + s.r.NoError(s.botClient.TearDownBot(context.Background(), botConfig.ContainerName(), true)) } func (s *BotClientTestSuite) TestStopBot() { diff --git a/services/components/containers/mocks/mock_bot_client.go b/services/components/containers/mocks/mock_bot_client.go index 5cee2353..60ae10c1 100644 --- a/services/components/containers/mocks/mock_bot_client.go +++ b/services/components/containers/mocks/mock_bot_client.go @@ -108,15 +108,15 @@ func (mr *MockBotClientMockRecorder) StopBot(ctx, botConfig interface{}) *gomock } // TearDownBot mocks base method. -func (m *MockBotClient) TearDownBot(ctx context.Context, containerName string) error { +func (m *MockBotClient) TearDownBot(ctx context.Context, containerName string, removeImage bool) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TearDownBot", ctx, containerName) + ret := m.ctrl.Call(m, "TearDownBot", ctx, containerName, removeImage) ret0, _ := ret[0].(error) return ret0 } // TearDownBot indicates an expected call of TearDownBot. -func (mr *MockBotClientMockRecorder) TearDownBot(ctx, containerName interface{}) *gomock.Call { +func (mr *MockBotClientMockRecorder) TearDownBot(ctx, containerName, removeImage interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TearDownBot", reflect.TypeOf((*MockBotClient)(nil).TearDownBot), ctx, containerName) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TearDownBot", reflect.TypeOf((*MockBotClient)(nil).TearDownBot), ctx, containerName, removeImage) } diff --git a/services/components/lifecycle/bot_manager.go b/services/components/lifecycle/bot_manager.go index e3a58eef..7c473375 100644 --- a/services/components/lifecycle/bot_manager.go +++ b/services/components/lifecycle/bot_manager.go @@ -78,7 +78,7 @@ func (blm *botLifecycleManager) ManageBots(ctx context.Context) error { // then stop the containers for _, removedBotConfig := range removedBotConfigs { - if err := blm.botClient.TearDownBot(ctx, removedBotConfig.ContainerName()); err != nil { + if err := blm.botClient.TearDownBot(ctx, removedBotConfig.ContainerName(), true); err != nil { log.WithError(err).WithField("container", removedBotConfig.ContainerName()). Warn("failed to tear down unassigned bot container") blm.lifecycleMetrics.BotError("unassigned.teardown", err, removedBotConfig.ID) @@ -155,7 +155,7 @@ func (blm *botLifecycleManager) CleanupUnusedBots(ctx context.Context) error { continue } - if err := blm.botClient.TearDownBot(ctx, botContainerName); err != nil { + if err := blm.botClient.TearDownBot(ctx, botContainerName, true); err != nil { log.WithField("botContainer", botContainerName).WithError(err). Error("error while tearing down the unused bot") } @@ -246,7 +246,7 @@ func (blm *botLifecycleManager) TearDownRunningBots(ctx context.Context) { // then stop the containers for _, runningBotConfig := range blm.runningBots { - err := blm.botClient.TearDownBot(ctx, runningBotConfig.ContainerName()) + err := blm.botClient.TearDownBot(ctx, runningBotConfig.ContainerName(), false) if err != nil { blm.lifecycleMetrics.BotError("teardown.bot", err, runningBotConfig.ID) log.WithError(err).WithField("container", runningBotConfig.ContainerName()). diff --git a/services/components/lifecycle/bot_manager_test.go b/services/components/lifecycle/bot_manager_test.go index 5b011b70..f839470f 100644 --- a/services/components/lifecycle/bot_manager_test.go +++ b/services/components/lifecycle/bot_manager_test.go @@ -92,7 +92,7 @@ func (s *BotLifecycleManagerTestSuite) TestAddUpdateRemove() { s.botPool.EXPECT().RemoveBotsWithConfigs([]config.AgentConfig{removedBot}) s.lifecycleMetrics.EXPECT().StatusStopping([]config.AgentConfig{removedBot}) - s.botContainers.EXPECT().TearDownBot(gomock.Any(), removedBot.ContainerName()) + s.botContainers.EXPECT().TearDownBot(gomock.Any(), removedBot.ContainerName(), true) s.lifecycleMetrics.EXPECT().StatusRunning(latestAssigned).Times(1) s.botPool.EXPECT().UpdateBotsWithLatestConfigs(latestAssigned) @@ -189,7 +189,27 @@ func (s *BotLifecycleManagerTestSuite) TestCleanup() { State: "exited", }, }, nil).Times(1) - s.botContainers.EXPECT().TearDownBot(gomock.Any(), unusedBotConfig.ContainerName()).Return(nil) + s.botContainers.EXPECT().TearDownBot(gomock.Any(), unusedBotConfig.ContainerName(), true).Return(nil) s.r.NoError(s.botManager.CleanupUnusedBots(context.Background())) } + +func (s *BotLifecycleManagerTestSuite) TestTearDown() { + botConfigs := []config.AgentConfig{ + { + ID: testBotID1, + Image: testImageRef, + }, + { + ID: testBotID2, + Image: testImageRef, + }, + } + s.botManager.runningBots = botConfigs + + s.botPool.EXPECT().RemoveBotsWithConfigs(botConfigs) + s.botContainers.EXPECT().TearDownBot(gomock.Any(), botConfigs[0].ContainerName(), false).Return(nil) + s.botContainers.EXPECT().TearDownBot(gomock.Any(), botConfigs[1].ContainerName(), false).Return(nil) + + s.botManager.TearDownRunningBots(context.Background()) +} diff --git a/services/components/lifecycle/bot_pool.go b/services/components/lifecycle/bot_pool.go index 7de52b8b..df2f3314 100644 --- a/services/components/lifecycle/bot_pool.go +++ b/services/components/lifecycle/bot_pool.go @@ -153,20 +153,20 @@ func (bp *botPool) RemoveBotsWithConfigs(removedBotConfigs messaging.AgentPayloa // close and discard the removed bots for _, removedBotConfig := range removedBotConfigs { logger := botLogger(removedBotConfig) - bot, ok := bp.getBotClient(removedBotConfig.ContainerName()) + botClient, ok := bp.getBotClient(removedBotConfig.ContainerName()) if !ok { logger.Info("could not find the removed bot! skipping") continue } - _ = bot.Close() + _ = botClient.Close() } // find the bots we are not supposed to remove and keep them var preservedBots []botio.BotClient for _, preservedBotConfig := range FindExtraBots(removedBotConfigs, bp.getConfigsUnsafe()) { - bot, ok := bp.getBotClient(preservedBotConfig.ContainerName()) + botClient, ok := bp.getBotClient(preservedBotConfig.ContainerName()) if ok { - preservedBots = append(preservedBots, bot) + preservedBots = append(preservedBots, botClient) } } diff --git a/services/components/lifecycle/lifecycle_test.go b/services/components/lifecycle/lifecycle_test.go index e901c917..e30ee66e 100644 --- a/services/components/lifecycle/lifecycle_test.go +++ b/services/components/lifecycle/lifecycle_test.go @@ -391,9 +391,9 @@ func (s *LifecycleTestSuite) TestUnassigned() { s.botGrpc.EXPECT().Initialize(gomock.Any(), gomock.Any()).Return(&protocol.InitializeResponse{}, nil). Times(1) - // and should shortly be tore down + // and should shortly be torn down s.lifecycleMetrics.EXPECT().StatusStopping(assigned[0]) - s.botContainers.EXPECT().TearDownBot(gomock.Any(), assigned[0].ContainerName()).Return(nil) + s.botContainers.EXPECT().TearDownBot(gomock.Any(), assigned[0].ContainerName(), true).Return(nil) s.lifecycleMetrics.EXPECT().StatusRunning().Times(1) s.lifecycleMetrics.EXPECT().ClientClose(assigned[0]) s.botGrpc.EXPECT().Close().AnyTimes() diff --git a/services/supervisor/bots.go b/services/supervisor/bots.go index 91cc0633..256ce6ae 100644 --- a/services/supervisor/bots.go +++ b/services/supervisor/bots.go @@ -10,6 +10,7 @@ import ( // This allows us to blast the latest assignment list very often // and keep bot containers and clients in order. func (sup *SupervisorService) refreshBotContainers() { + sup.doRefreshBotContainers() for { select { case <-sup.ctx.Done():