diff --git a/pkg/consensus/mimicry/p2p/reqresp/v1/chunked_handler.go b/pkg/consensus/mimicry/p2p/reqresp/v1/chunked_handler.go index 0e942ec..cee2c67 100644 --- a/pkg/consensus/mimicry/p2p/reqresp/v1/chunked_handler.go +++ b/pkg/consensus/mimicry/p2p/reqresp/v1/chunked_handler.go @@ -137,8 +137,22 @@ func NewChunkedHandler[TReq, TResp any]( func (h *ChunkedHandler[TReq, TResp]) HandleStream(ctx context.Context, stream network.Stream) { defer stream.Close() - // Set deadline if configured + // Recover from panics + defer func() { + if r := recover(); r != nil { + h.log.WithField("panic", r).Error("Chunked handler panicked") + _ = h.writeErrorResponse(stream, StatusServerError) + } + }() + + // Create context with timeout if configured + handlerCtx := ctx + + var cancel context.CancelFunc if h.config.RequestTimeout > 0 { + handlerCtx, cancel = context.WithTimeout(ctx, h.config.RequestTimeout) + defer cancel() + deadline := time.Now().Add(h.config.RequestTimeout) if err := stream.SetDeadline(deadline); err != nil { h.log.WithError(err).Debug("Failed to set stream deadline") @@ -168,7 +182,7 @@ func (h *ChunkedHandler[TReq, TResp]) HandleStream(ctx context.Context, stream n } // Process request with chunked writer - err = h.handler(ctx, req, peerID, writer) + err = h.handler(handlerCtx, req, peerID, writer) if err != nil { h.log.WithError(err).WithField("peer", peerID).Debug("Chunked handler returned error") // Try to send error status if writer hasn't written anything yet @@ -194,6 +208,10 @@ func (h *ChunkedHandler[TReq, TResp]) readRequest(stream network.Stream) (TReq, } size := binary.BigEndian.Uint32(sizeBytes[:]) + if size == 0 { + return req, fmt.Errorf("empty request") + } + if uint64(size) > h.protocol.MaxRequestSize() { return req, fmt.Errorf("request size %d exceeds max %d", size, h.protocol.MaxRequestSize()) } diff --git a/pkg/consensus/mimicry/p2p/reqresp/v1/chunked_handler_test.go b/pkg/consensus/mimicry/p2p/reqresp/v1/chunked_handler_test.go new file 
mode 100644 index 0000000..f3ad61a --- /dev/null +++ b/pkg/consensus/mimicry/p2p/reqresp/v1/chunked_handler_test.go @@ -0,0 +1,680 @@ +package v1 + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const testString = "test" + +func TestNewChunkedHandler(t *testing.T) { + proto := testChunkedProtocol{ + testProtocol: testProtocol{ + id: "/test/chunked/1.0.0", + maxRequestSize: 1024, + maxResponseSize: 2048, + }, + chunked: true, + } + + handler := func(ctx context.Context, req testRequest, from peer.ID, writer ChunkedResponseWriter[testResponse]) error { + return writer.WriteChunk(testResponse{Message: "chunk1", ID: req.ID}) + } + + opts := HandlerOptions{ + Encoder: &mockEncoder{}, + RequestTimeout: 10 * time.Second, + } + + logger := logrus.New() + + h := NewChunkedHandler(proto, handler, opts, logger) + require.NotNil(t, h) + + // Verify it implements StreamHandler + var _ StreamHandler = h +} + +func TestChunkedHandler_HandleStream(t *testing.T) { + logger := logrus.New() + logger.SetLevel(logrus.DebugLevel) + + tests := []struct { + name string + setupStream func() *mockStream + handler ChunkedRequestHandler[testRequest, testResponse] + encoder Encoder + compressor Compressor + maxRequestSize uint64 + expectedChunks int + expectedStatus []Status + expectedMessages []string + expectedError bool + }{ + { + name: "successful_single_chunk", + setupStream: func() *mockStream { + stream := newMockStream("test-stream", "/test/chunked/1.0.0", "remote", "local") + // Prepare request data + reqData := []byte("ping") + var buf []byte + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqData))) + buf = append(buf, sizeBuf...) + buf = append(buf, reqData...) 
+ stream.setReadData(buf) + + return stream + }, + handler: func(ctx context.Context, req testRequest, from peer.ID, writer ChunkedResponseWriter[testResponse]) error { + return writer.WriteChunk(testResponse{Message: "pong", ID: req.ID}) + }, + encoder: &mockEncoder{ + encodeFunc: func(msg any) ([]byte, error) { + if resp, ok := msg.(testResponse); ok { + return []byte(resp.Message), nil + } + + return []byte("ping"), nil + }, + decodeFunc: func(data []byte, msgType any) error { + if req, ok := msgType.(*testRequest); ok { + req.Message = string(data) + req.ID = 1 + + return nil + } + + return nil + }, + }, + maxRequestSize: 1024, + expectedChunks: 1, + expectedStatus: []Status{StatusSuccess}, + expectedMessages: []string{"pong"}, + }, + { + name: "successful_multiple_chunks", + setupStream: func() *mockStream { + stream := newMockStream("test-stream", "/test/chunked/1.0.0", "remote", "local") + reqData := []byte("ping") + var buf []byte + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqData))) + buf = append(buf, sizeBuf...) + buf = append(buf, reqData...) 
+ stream.setReadData(buf) + + return stream + }, + handler: func(ctx context.Context, req testRequest, from peer.ID, writer ChunkedResponseWriter[testResponse]) error { + // Write multiple chunks + chunks := []string{"chunk1", "chunk2", "chunk3"} + for i, chunk := range chunks { + if err := writer.WriteChunk(testResponse{Message: chunk, ID: i}); err != nil { + return err + } + } + + return nil + }, + encoder: &mockEncoder{ + encodeFunc: func(msg any) ([]byte, error) { + if resp, ok := msg.(testResponse); ok { + return []byte(resp.Message), nil + } + + return []byte("ping"), nil + }, + decodeFunc: func(data []byte, msgType any) error { + if req, ok := msgType.(*testRequest); ok { + req.Message = string(data) + req.ID = 1 + + return nil + } + + return nil + }, + }, + maxRequestSize: 1024, + expectedChunks: 3, + expectedStatus: []Status{StatusSuccess, StatusSuccess, StatusSuccess}, + expectedMessages: []string{"chunk1", "chunk2", "chunk3"}, + }, + { + name: "handler_error", + setupStream: func() *mockStream { + stream := newMockStream("test-stream", "/test/chunked/1.0.0", "remote", "local") + reqData := []byte("ping") + var buf []byte + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqData))) + buf = append(buf, sizeBuf...) + buf = append(buf, reqData...) 
+ stream.setReadData(buf) + + return stream + }, + handler: func(ctx context.Context, req testRequest, from peer.ID, writer ChunkedResponseWriter[testResponse]) error { + return errors.New("handler error") + }, + encoder: &mockEncoder{ + decodeFunc: func(data []byte, msgType any) error { + if req, ok := msgType.(*testRequest); ok { + req.Message = string(data) + req.ID = 1 + + return nil + } + + return nil + }, + }, + maxRequestSize: 1024, + expectedChunks: 1, + expectedStatus: []Status{StatusServerError}, + expectedError: false, + }, + { + name: "write_chunk_after_error", + setupStream: func() *mockStream { + stream := newMockStream("test-stream", "/test/chunked/1.0.0", "remote", "local") + reqData := []byte("ping") + var buf []byte + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqData))) + buf = append(buf, sizeBuf...) + buf = append(buf, reqData...) + stream.setReadData(buf) + + return stream + }, + handler: func(ctx context.Context, req testRequest, from peer.ID, writer ChunkedResponseWriter[testResponse]) error { + // Write first chunk successfully + if err := writer.WriteChunk(testResponse{Message: "chunk1", ID: 1}); err != nil { + return err + } + // Simulate an error occurring + // The writer should handle subsequent writes gracefully + return errors.New("error after first chunk") + }, + encoder: &mockEncoder{ + encodeFunc: func(msg any) ([]byte, error) { + if resp, ok := msg.(testResponse); ok { + return []byte(resp.Message), nil + } + + return []byte("ping"), nil + }, + decodeFunc: func(data []byte, msgType any) error { + if req, ok := msgType.(*testRequest); ok { + req.Message = string(data) + req.ID = 1 + + return nil + } + + return nil + }, + }, + maxRequestSize: 1024, + expectedChunks: 2, + expectedStatus: []Status{StatusSuccess, StatusServerError}, + expectedMessages: []string{"chunk1"}, + expectedError: false, + }, + { + name: "with_compression", + setupStream: func() *mockStream { + stream := 
newMockStream("test-stream", "/test/chunked/1.0.0", "remote", "local") + // Prepare compressed request + reqData := []byte("COMPRESSED:ping") + var buf []byte + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqData))) + buf = append(buf, sizeBuf...) + buf = append(buf, reqData...) + stream.setReadData(buf) + + return stream + }, + handler: func(ctx context.Context, req testRequest, from peer.ID, writer ChunkedResponseWriter[testResponse]) error { + return writer.WriteChunk(testResponse{Message: "pong", ID: req.ID}) + }, + encoder: &mockEncoder{ + encodeFunc: func(msg any) ([]byte, error) { + if resp, ok := msg.(testResponse); ok { + return []byte(resp.Message), nil + } + + return []byte("ping"), nil + }, + decodeFunc: func(data []byte, msgType any) error { + if req, ok := msgType.(*testRequest); ok { + req.Message = string(data) + req.ID = 1 + + return nil + } + + return nil + }, + }, + compressor: &mockCompressor{}, + maxRequestSize: 1024, + expectedChunks: 1, + expectedStatus: []Status{StatusSuccess}, + expectedMessages: []string{"COMPRESSED:pong"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stream := tt.setupStream() + ctx := context.Background() + + proto := testChunkedProtocol{ + testProtocol: testProtocol{ + id: "/test/chunked/1.0.0", + maxRequestSize: tt.maxRequestSize, + maxResponseSize: 2048, + }, + chunked: true, + } + + opts := HandlerOptions{ + Encoder: tt.encoder, + Compressor: tt.compressor, + RequestTimeout: 5 * time.Second, + } + + h := NewChunkedHandler(proto, tt.handler, opts, logger) + + // Handle the stream + h.HandleStream(ctx, stream) + + // Parse written data to extract chunks + written := stream.getWrittenData() + chunks := parseChunkedResponse(t, written) + + // Verify chunk count + assert.Equal(t, tt.expectedChunks, len(chunks)) + + // Verify each chunk + for i, chunk := range chunks { + if i < len(tt.expectedStatus) { + assert.Equal(t, tt.expectedStatus[i], chunk.status) + 
} + if i < len(tt.expectedMessages) { + assert.Equal(t, tt.expectedMessages[i], string(chunk.data)) + } + } + + // If we expect an error status at the end + if tt.expectedError && len(written) > 0 { + // The last byte might be an error status if handler returned error + lastByte := written[len(written)-1] + if lastByte == byte(StatusServerError) { + // This is expected for handler errors + assert.True(t, true) + } + } + }) + } +} + +type parsedChunk struct { + status Status + data []byte +} + +func parseChunkedResponse(t *testing.T, data []byte) []parsedChunk { + t.Helper() + + var chunks []parsedChunk + offset := 0 + + for offset < len(data) { + // Read status byte + if offset >= len(data) { + break + } + status := Status(data[offset]) + offset++ + + // If error status, no data follows + if status != StatusSuccess { + chunks = append(chunks, parsedChunk{status: status}) + + continue + } + + // Read size + if offset+4 > len(data) { + break + } + size := binary.BigEndian.Uint32(data[offset : offset+4]) + offset += 4 + + // Read data + if offset+int(size) > len(data) { + break + } + chunkData := data[offset : offset+int(size)] + offset += int(size) + + chunks = append(chunks, parsedChunk{ + status: status, + data: chunkData, + }) + } + + return chunks +} + +func TestChunkedResponseWriter(t *testing.T) { + logger := logrus.New() + + tests := []struct { + name string + setupWriter func() *streamChunkedWriter[testResponse] + chunks []testResponse + expectedError string + verifyWrite func(t *testing.T, stream *mockStream) + }{ + { + name: "write_single_chunk", + setupWriter: func() *streamChunkedWriter[testResponse] { + stream := newMockStream("test", "/test/1.0.0", "local", "remote") + + return &streamChunkedWriter[testResponse]{ + stream: stream, + encoder: &mockEncoder{ + encodeFunc: func(msg any) ([]byte, error) { + if resp, ok := msg.(testResponse); ok { + return []byte(resp.Message), nil + } + + return nil, errors.New("unknown type") + }, + }, + maxSize: 1024, + 
log: logger, + } + }, + chunks: []testResponse{ + {Message: "test chunk", ID: 1}, + }, + verifyWrite: func(t *testing.T, stream *mockStream) { + t.Helper() + data := stream.getWrittenData() + chunks := parseChunkedResponse(t, data) + require.Equal(t, 1, len(chunks)) + assert.Equal(t, StatusSuccess, chunks[0].status) + assert.Equal(t, "test chunk", string(chunks[0].data)) + }, + }, + { + name: "write_multiple_chunks", + setupWriter: func() *streamChunkedWriter[testResponse] { + stream := newMockStream("test", "/test/1.0.0", "local", "remote") + + return &streamChunkedWriter[testResponse]{ + stream: stream, + encoder: &mockEncoder{ + encodeFunc: func(msg any) ([]byte, error) { + if resp, ok := msg.(testResponse); ok { + return []byte(resp.Message), nil + } + + return nil, errors.New("unknown type") + }, + }, + maxSize: 1024, + log: logger, + } + }, + chunks: []testResponse{ + {Message: "chunk1", ID: 1}, + {Message: "chunk2", ID: 2}, + {Message: "chunk3", ID: 3}, + }, + verifyWrite: func(t *testing.T, stream *mockStream) { + t.Helper() + data := stream.getWrittenData() + chunks := parseChunkedResponse(t, data) + require.Equal(t, 3, len(chunks)) + for i, chunk := range chunks { + assert.Equal(t, StatusSuccess, chunk.status) + assert.Equal(t, fmt.Sprintf("chunk%d", i+1), string(chunk.data)) + } + }, + }, + { + name: "chunk_exceeds_max_size", + setupWriter: func() *streamChunkedWriter[testResponse] { + stream := newMockStream("test", "/test/1.0.0", "local", "remote") + + return &streamChunkedWriter[testResponse]{ + stream: stream, + encoder: &mockEncoder{ + encodeFunc: func(msg any) ([]byte, error) { + // Return data that exceeds max size + return make([]byte, 2048), nil + }, + }, + maxSize: 1024, + log: logger, + } + }, + chunks: []testResponse{ + {Message: "too large", ID: 1}, + }, + expectedError: "exceeds max", + }, + { + name: "encoder_error", + setupWriter: func() *streamChunkedWriter[testResponse] { + stream := newMockStream("test", "/test/1.0.0", "local", 
"remote") + + return &streamChunkedWriter[testResponse]{ + stream: stream, + encoder: &mockEncoder{ + encodeFunc: func(msg any) ([]byte, error) { + return nil, errors.New("encode error") + }, + }, + maxSize: 1024, + log: logger, + } + }, + chunks: []testResponse{ + {Message: "test", ID: 1}, + }, + expectedError: "encode error", + }, + { + name: "with_compression", + setupWriter: func() *streamChunkedWriter[testResponse] { + stream := newMockStream("test", "/test/1.0.0", "local", "remote") + + return &streamChunkedWriter[testResponse]{ + stream: stream, + encoder: &mockEncoder{ + encodeFunc: func(msg any) ([]byte, error) { + if resp, ok := msg.(testResponse); ok { + return []byte(resp.Message), nil + } + + return nil, errors.New("unknown type") + }, + }, + compressor: &mockCompressor{}, + maxSize: 1024, + log: logger, + } + }, + chunks: []testResponse{ + {Message: "test chunk", ID: 1}, + }, + verifyWrite: func(t *testing.T, stream *mockStream) { + t.Helper() + data := stream.getWrittenData() + chunks := parseChunkedResponse(t, data) + require.Equal(t, 1, len(chunks)) + assert.Equal(t, StatusSuccess, chunks[0].status) + assert.Equal(t, "COMPRESSED:test chunk", string(chunks[0].data)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + writer := tt.setupWriter() + + var err error + for _, chunk := range tt.chunks { + if e := writer.WriteChunk(chunk); e != nil { + err = e + + break + } + } + + if tt.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + } else { + require.NoError(t, err) + if tt.verifyWrite != nil { + if stream, ok := writer.stream.(*mockStream); ok { + tt.verifyWrite(t, stream) + } + } + } + }) + } +} + +func TestChunkedHandler_TimeoutHandling(t *testing.T) { + proto := testChunkedProtocol{ + testProtocol: testProtocol{ + id: "/test/chunked/1.0.0", + maxRequestSize: 1024, + maxResponseSize: 2048, + }, + chunked: true, + } + + // Handler that takes longer than timeout + 
slowHandler := func(ctx context.Context, req testRequest, from peer.ID, writer ChunkedResponseWriter[testResponse]) error { + select { + case <-time.After(200 * time.Millisecond): + return writer.WriteChunk(testResponse{Message: "too late"}) + case <-ctx.Done(): + return ctx.Err() + } + } + + opts := HandlerOptions{ + Encoder: &mockEncoder{ + decodeFunc: func(data []byte, msgType any) error { + if req, ok := msgType.(*testRequest); ok { + req.Message = testString + + return nil + } + + return nil + }, + }, + RequestTimeout: 50 * time.Millisecond, // Very short timeout + } + + logger := logrus.New() + h := NewChunkedHandler(proto, slowHandler, opts, logger) + + // Setup stream with valid request + stream := newMockStream("test-stream", "/test/chunked/1.0.0", "remote", "local") + reqData := []byte(testString + " request") + var buf []byte + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqData))) + buf = append(buf, sizeBuf...) + buf = append(buf, reqData...) + stream.setReadData(buf) + + ctx := context.Background() + h.HandleStream(ctx, stream) + + // Verify error response was written + written := stream.getWrittenData() + require.GreaterOrEqual(t, len(written), 1) + assert.Equal(t, byte(StatusServerError), written[0]) +} + +func TestChunkedHandler_PanicRecovery(t *testing.T) { + proto := testChunkedProtocol{ + testProtocol: testProtocol{ + id: "/test/chunked/1.0.0", + maxRequestSize: 1024, + maxResponseSize: 2048, + }, + chunked: true, + } + + // Handler that panics + panicHandler := func(ctx context.Context, req testRequest, from peer.ID, writer ChunkedResponseWriter[testResponse]) error { + panic("handler panic") + } + + opts := HandlerOptions{ + Encoder: &mockEncoder{ + decodeFunc: func(data []byte, msgType any) error { + if req, ok := msgType.(*testRequest); ok { + req.Message = testString + + return nil + } + + return nil + }, + }, + RequestTimeout: 5 * time.Second, + } + + logger := logrus.New() + h := NewChunkedHandler(proto, 
panicHandler, opts, logger) + + // Setup stream with valid request + stream := newMockStream("test-stream", "/test/chunked/1.0.0", "remote", "local") + reqData := []byte(testString + " request") + var buf []byte + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqData))) + buf = append(buf, sizeBuf...) + buf = append(buf, reqData...) + stream.setReadData(buf) + + ctx := context.Background() + + // Should not panic + assert.NotPanics(t, func() { + h.HandleStream(ctx, stream) + }) + + // Verify error response was written + written := stream.getWrittenData() + require.GreaterOrEqual(t, len(written), 1) + assert.Equal(t, byte(StatusServerError), written[0]) +} diff --git a/pkg/consensus/mimicry/p2p/reqresp/v1/client_test.go b/pkg/consensus/mimicry/p2p/reqresp/v1/client_test.go new file mode 100644 index 0000000..8592b7c --- /dev/null +++ b/pkg/consensus/mimicry/p2p/reqresp/v1/client_test.go @@ -0,0 +1,611 @@ +package v1 + +import ( + "context" + "encoding/binary" + "errors" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const testRequestString = "test request" + +func TestNewClient(t *testing.T) { + host := newMockHost("test-peer") + config := ClientConfig{ + DefaultTimeout: 10 * time.Second, + MaxRetries: 3, + RetryDelay: 100 * time.Millisecond, + } + logger := logrus.New() + + client := NewClient(host, config, logger) + require.NotNil(t, client) + + // Verify client is not nil and implements the interface + var _ = client +} + +func TestClient_SendRequest(t *testing.T) { + ctx := context.Background() + host := newMockHost("test-peer") + config := ClientConfig{ + DefaultTimeout: 1 * time.Second, + MaxRetries: 0, + } + logger := logrus.New() + logger.SetLevel(logrus.DebugLevel) + + client := NewClient(host, config, 
logger) + + tests := []struct { + name string + setupStream func() *mockStream + encoder Encoder + expectedError string + }{ + { + name: "successful_request", + setupStream: func() *mockStream { + stream := newMockStream("test-stream", "/test/1.0.0", "local", "remote") + // Prepare response data + respData := []byte("test response") + var buf []byte + buf = append(buf, byte(StatusSuccess)) + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(respData))) + buf = append(buf, sizeBuf...) + buf = append(buf, respData...) + stream.setReadData(buf) + + return stream + }, + encoder: &mockEncoder{}, + }, + { + name: "stream_creation_fails", + setupStream: func() *mockStream { + return nil + }, + encoder: &mockEncoder{}, + expectedError: "failed to open stream", + }, + { + name: "encoder_not_provided", + setupStream: func() *mockStream { + return newMockStream("test-stream", "/test/1.0.0", "local", "remote") + }, + encoder: nil, + expectedError: "encoder must be provided", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup mock host behavior + if tt.setupStream != nil { + stream := tt.setupStream() + host.newStreamFunc = func(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) { + if stream == nil { + return nil, errors.New("stream creation failed") + } + + return stream, nil + } + } + + req := testRequestString + var resp string + + opts := RequestOptions{ + Encoder: tt.encoder, + } + + err := client.SendRequestWithOptions(ctx, "remote-peer", "/test/1.0.0", &req, &resp, opts) + + if tt.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + } else { + require.NoError(t, err) + assert.Equal(t, "test response", resp) + } + }) + } +} + +func TestClient_SendRequestWithTimeout(t *testing.T) { + ctx := context.Background() + host := newMockHost("test-peer") + config := ClientConfig{ + DefaultTimeout: 1 * time.Second, + MaxRetries: 0, + } + logger := 
logrus.New() + + client := NewClient(host, config, logger) + + // Test that custom timeout is applied + customTimeout := 500 * time.Millisecond + startTime := time.Now() + + // Setup a stream that delays response + stream := newMockStream("test-stream", "/test/1.0.0", "local", "remote") + host.newStreamFunc = func(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) { + // Check if context has timeout + deadline, ok := ctx.Deadline() + if ok { + timeUntilDeadline := time.Until(deadline) + // Verify timeout is approximately what we set + assert.InDelta(t, customTimeout.Milliseconds(), timeUntilDeadline.Milliseconds(), 100) + } + + return stream, nil + } + + req := testRequestString + var resp string + + err := client.SendRequestWithTimeout(ctx, "remote-peer", "/test/1.0.0", &req, &resp, customTimeout) + elapsed := time.Since(startTime) + + // Should fail because no encoder is provided by default + require.Error(t, err) + assert.Contains(t, err.Error(), "encoder must be provided") + assert.Less(t, elapsed, customTimeout+100*time.Millisecond) +} + +func TestClient_RetryLogic(t *testing.T) { + ctx := context.Background() + host := newMockHost("test-peer") + config := ClientConfig{ + DefaultTimeout: 500 * time.Millisecond, + MaxRetries: 2, + RetryDelay: 50 * time.Millisecond, + } + logger := logrus.New() + logger.SetLevel(logrus.DebugLevel) + + client := NewClient(host, config, logger) + + attemptCount := 0 + host.newStreamFunc = func(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) { + attemptCount++ + if attemptCount <= 2 { + return nil, errors.New("temporary failure") + } + // Success on third attempt + stream := newMockStream("test-stream", "/test/1.0.0", "local", "remote") + // Prepare successful response + respData := []byte("success after retries") + var buf []byte + buf = append(buf, byte(StatusSuccess)) + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(respData))) + buf = 
append(buf, sizeBuf...) + buf = append(buf, respData...) + stream.setReadData(buf) + + return stream, nil + } + + req := testRequestString + var resp string + + opts := RequestOptions{ + Encoder: &mockEncoder{}, + } + + err := client.SendRequestWithOptions(ctx, "remote-peer", "/test/1.0.0", &req, &resp, opts) + require.NoError(t, err) + assert.Equal(t, "success after retries", resp) + assert.Equal(t, 3, attemptCount) +} + +func TestClient_WriteRequest(t *testing.T) { + client := &client{ + log: logrus.New(), + } + + tests := []struct { + name string + request any + encoder Encoder + compressor Compressor + expectedError string + verifyWrite func(t *testing.T, stream *mockStream) + }{ + { + name: "successful_write_no_compression", + request: testRequestString, + encoder: &mockEncoder{}, + verifyWrite: func(t *testing.T, stream *mockStream) { + t.Helper() + data := stream.getWrittenData() + require.GreaterOrEqual(t, len(data), 4) + + // Check size prefix + size := binary.BigEndian.Uint32(data[:4]) + assert.Equal(t, uint32(len(testRequestString)), size) + + // Check data + assert.Equal(t, testRequestString, string(data[4:])) + }, + }, + { + name: "successful_write_with_compression", + request: "test request", + encoder: &mockEncoder{}, + compressor: &mockCompressor{}, + verifyWrite: func(t *testing.T, stream *mockStream) { + t.Helper() + data := stream.getWrittenData() + require.GreaterOrEqual(t, len(data), 4) + + // Check size prefix + size := binary.BigEndian.Uint32(data[:4]) + expectedCompressed := "COMPRESSED:" + testRequestString + assert.Equal(t, uint32(len(expectedCompressed)), size) + + // Check compressed data + assert.Equal(t, expectedCompressed, string(data[4:])) + }, + }, + { + name: "encoder_fails", + request: testRequestString, + encoder: &mockEncoder{ + encodeFunc: func(msg any) ([]byte, error) { + return nil, errors.New("encode error") + }, + }, + expectedError: "failed to encode request", + }, + { + name: "compressor_fails", + request: 
testRequestString, + encoder: &mockEncoder{}, + compressor: &mockCompressor{ + compressFunc: func(data []byte) ([]byte, error) { + return nil, errors.New("compress error") + }, + }, + expectedError: "failed to compress request", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stream := newMockStream("test-stream", "/test/1.0.0", "local", "remote") + opts := RequestOptions{ + Encoder: tt.encoder, + Compressor: tt.compressor, + } + + err := client.writeRequest(stream, tt.request, opts) + + if tt.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + } else { + require.NoError(t, err) + if tt.verifyWrite != nil { + tt.verifyWrite(t, stream) + } + } + }) + } +} + +func TestClient_ReadResponse(t *testing.T) { + client := &client{ + log: logrus.New(), + } + + tests := []struct { + name string + setupStream func() *mockStream + encoder Encoder + compressor Compressor + expectedResp string + expectedError string + }{ + { + name: "successful_read_no_compression", + setupStream: func() *mockStream { + stream := newMockStream("test-stream", "/test/1.0.0", "local", "remote") + respData := []byte("test response") + var buf []byte + buf = append(buf, byte(StatusSuccess)) + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(respData))) + buf = append(buf, sizeBuf...) + buf = append(buf, respData...) + stream.setReadData(buf) + + return stream + }, + encoder: &mockEncoder{}, + expectedResp: "test response", + }, + { + name: "successful_read_with_compression", + setupStream: func() *mockStream { + stream := newMockStream("test-stream", "/test/1.0.0", "local", "remote") + respData := []byte("COMPRESSED:test response") + var buf []byte + buf = append(buf, byte(StatusSuccess)) + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(respData))) + buf = append(buf, sizeBuf...) + buf = append(buf, respData...) 
+ stream.setReadData(buf) + + return stream + }, + encoder: &mockEncoder{}, + compressor: &mockCompressor{}, + expectedResp: "test response", + }, + { + name: "error_status", + setupStream: func() *mockStream { + stream := newMockStream("test-stream", "/test/1.0.0", "local", "remote") + stream.setReadData([]byte{byte(StatusServerError)}) + + return stream + }, + encoder: &mockEncoder{}, + expectedError: "server returned error status: server_error", + }, + { + name: "empty_response", + setupStream: func() *mockStream { + stream := newMockStream("test-stream", "/test/1.0.0", "local", "remote") + var buf []byte + buf = append(buf, byte(StatusSuccess)) + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, 0) + buf = append(buf, sizeBuf...) + stream.setReadData(buf) + + return stream + }, + encoder: &mockEncoder{}, + expectedError: "received empty response", + }, + { + name: "decoder_fails", + setupStream: func() *mockStream { + stream := newMockStream("test-stream", "/test/1.0.0", "local", "remote") + respData := []byte("test response") + var buf []byte + buf = append(buf, byte(StatusSuccess)) + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(respData))) + buf = append(buf, sizeBuf...) + buf = append(buf, respData...) 
+ stream.setReadData(buf) + + return stream + }, + encoder: &mockEncoder{ + decodeFunc: func(data []byte, msgType any) error { + return errors.New("decode error") + }, + }, + expectedError: "failed to decode response", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stream := tt.setupStream() + opts := RequestOptions{ + Encoder: tt.encoder, + Compressor: tt.compressor, + } + + var resp string + err := client.readResponse(stream, &resp, opts) + + if tt.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedResp, resp) + } + }) + } +} + +func TestClient_ContextCancellation(t *testing.T) { + host := newMockHost("test-peer") + config := ClientConfig{ + DefaultTimeout: 5 * time.Second, + MaxRetries: 2, + RetryDelay: 100 * time.Millisecond, + } + logger := logrus.New() + + client := NewClient(host, config, logger) + + // Create a context that we'll cancel + ctx, cancel := context.WithCancel(context.Background()) + + attemptCount := 0 + host.newStreamFunc = func(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) { + attemptCount++ + if attemptCount == 1 { + // Cancel context after first attempt + cancel() + } + + return nil, errors.New("temporary failure") + } + + req := testRequestString + var resp string + + opts := RequestOptions{ + Encoder: &mockEncoder{}, + } + + err := client.SendRequestWithOptions(ctx, "remote-peer", "/test/1.0.0", &req, &resp, opts) + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + assert.Equal(t, 1, attemptCount) // Should not retry after context cancellation +} + +func TestRequest_FluentAPI(t *testing.T) { + host := newMockHost("test-peer") + config := ClientConfig{ + DefaultTimeout: 1 * time.Second, + } + logger := logrus.New() + + client := NewClient(host, config, logger) + proto := testProtocol{ + id: "/test/1.0.0", + maxRequestSize: 1024, + 
maxResponseSize: 1024, + } + + t.Run("successful_request", func(t *testing.T) { + req := NewRequest[testRequest, testResponse](client, proto) + require.NotNil(t, req) + + // Test fluent API + req = req.To("remote-peer").WithTimeout(500 * time.Millisecond) + assert.Equal(t, peer.ID("remote-peer"), req.peerID) + assert.Equal(t, 500*time.Millisecond, req.timeout) + }) + + t.Run("missing_peer_id", func(t *testing.T) { + req := NewRequest[testRequest, testResponse](client, proto) + + ctx := context.Background() + testReq := testRequest{Message: "test", ID: 1} + + _, err := req.Send(ctx, testReq) + require.Error(t, err) + assert.Contains(t, err.Error(), "peer ID not set") + }) +} + +func TestChunkedClient(t *testing.T) { + host := newMockHost("test-peer") + config := ClientConfig{ + DefaultTimeout: 1 * time.Second, + MaxRetries: 0, + } + logger := logrus.New() + + chunkedClient := NewChunkedClient(host, config, logger) + require.NotNil(t, chunkedClient) + + t.Run("send_chunked_request", func(t *testing.T) { + // Setup mock stream with multiple chunks + stream := newMockStream("test-stream", "/test/1.0.0", "local", "remote") + + // Prepare multiple chunks + chunks := []string{"chunk1", "chunk2", "chunk3"} + var buf []byte + for _, chunk := range chunks { + buf = append(buf, byte(StatusSuccess)) + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(chunk))) + buf = append(buf, sizeBuf...) + buf = append(buf, chunk...) 
+ } + stream.setReadData(buf) + + host.newStreamFunc = func(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) { + return stream, nil + } + + ctx := context.Background() + req := testRequestString + + receivedChunks := []string{} + chunkHandler := func(chunk any) error { + // In real implementation, chunk would be decoded + // For now, we just store the raw data + if data, ok := chunk.([]byte); ok { + receivedChunks = append(receivedChunks, string(data)) + } + + return nil + } + + opts := RequestOptions{ + Encoder: &mockEncoder{}, + } + + err := chunkedClient.SendChunkedRequestWithOptions(ctx, "remote-peer", "/test/1.0.0", &req, chunkHandler, opts) + require.NoError(t, err) + assert.Equal(t, 3, len(receivedChunks)) + }) + + t.Run("chunk_handler_error", func(t *testing.T) { + stream := newMockStream("test-stream", "/test/1.0.0", "local", "remote") + + // Prepare a chunk + var buf []byte + buf = append(buf, byte(StatusSuccess)) + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len("chunk1"))) + buf = append(buf, sizeBuf...) + buf = append(buf, []byte("chunk1")...) 
+ stream.setReadData(buf) + + host.newStreamFunc = func(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) { + return stream, nil + } + + ctx := context.Background() + req := testRequestString + + chunkHandler := func(chunk any) error { + return errors.New("handler error") + } + + opts := RequestOptions{ + Encoder: &mockEncoder{}, + } + + err := chunkedClient.SendChunkedRequestWithOptions(ctx, "remote-peer", "/test/1.0.0", &req, chunkHandler, opts) + require.Error(t, err) + assert.Contains(t, err.Error(), "chunk handler error") + }) +} + +func TestClient_DataSizeValidation(t *testing.T) { + client := &client{ + log: logrus.New(), + } + + // Create a very large message that would exceed uint32 max when encoded + largeData := make([]byte, 1<<32) // 4GB + + stream := newMockStream("test-stream", "/test/1.0.0", "local", "remote") + opts := RequestOptions{ + Encoder: &mockEncoder{ + encodeFunc: func(msg any) ([]byte, error) { + return largeData, nil + }, + }, + } + + err := client.writeRequest(stream, "test", opts) + require.Error(t, err) + assert.Contains(t, err.Error(), "data size") +} diff --git a/pkg/consensus/mimicry/p2p/reqresp/v1/example_test.go b/pkg/consensus/mimicry/p2p/reqresp/v1/example_test.go index 56491f7..c1dfee4 100644 --- a/pkg/consensus/mimicry/p2p/reqresp/v1/example_test.go +++ b/pkg/consensus/mimicry/p2p/reqresp/v1/example_test.go @@ -266,6 +266,7 @@ func Example_middleware() { if err != nil { return "", err } + return fmt.Sprintf("[Rate-Limited] %s", resp), nil } diff --git a/pkg/consensus/mimicry/p2p/reqresp/v1/handler.go b/pkg/consensus/mimicry/p2p/reqresp/v1/handler.go index 399f755..f858f2a 100644 --- a/pkg/consensus/mimicry/p2p/reqresp/v1/handler.go +++ b/pkg/consensus/mimicry/p2p/reqresp/v1/handler.go @@ -43,8 +43,22 @@ func NewHandler[TReq, TResp any]( func (h *Handler[TReq, TResp]) HandleStream(ctx context.Context, stream network.Stream) { defer stream.Close() - // Set deadline if configured + // Recover from 
panics + defer func() { + if r := recover(); r != nil { + h.log.WithField("panic", r).Error("Handler panicked") + _ = h.writeErrorResponse(stream, StatusServerError) + } + }() + + // Create context with timeout if configured + handlerCtx := ctx + + var cancel context.CancelFunc if h.config.RequestTimeout > 0 { + handlerCtx, cancel = context.WithTimeout(ctx, h.config.RequestTimeout) + defer cancel() + deadline := time.Now().Add(h.config.RequestTimeout) if err := stream.SetDeadline(deadline); err != nil { h.log.WithError(err).Debug("Failed to set stream deadline") @@ -65,7 +79,7 @@ func (h *Handler[TReq, TResp]) HandleStream(ctx context.Context, stream network. } // Process request - resp, err := h.handler(ctx, req, peerID) + resp, err := h.handler(handlerCtx, req, peerID) if err != nil { h.log.WithError(err).WithField("peer", peerID).Debug("Handler returned error") _ = h.writeErrorResponse(stream, StatusServerError) @@ -90,6 +104,10 @@ func (h *Handler[TReq, TResp]) readRequest(stream network.Stream) (TReq, error) } size := binary.BigEndian.Uint32(sizeBytes[:]) + if size == 0 { + return req, fmt.Errorf("empty request") + } + if uint64(size) > h.protocol.MaxRequestSize() { return req, fmt.Errorf("request size %d exceeds max %d", size, h.protocol.MaxRequestSize()) } diff --git a/pkg/consensus/mimicry/p2p/reqresp/v1/handler_test.go b/pkg/consensus/mimicry/p2p/reqresp/v1/handler_test.go new file mode 100644 index 0000000..4671c8c --- /dev/null +++ b/pkg/consensus/mimicry/p2p/reqresp/v1/handler_test.go @@ -0,0 +1,598 @@ +package v1 + +import ( + "context" + "encoding/binary" + "errors" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewHandler(t *testing.T) { + proto := testProtocol{ + id: "/test/1.0.0", + maxRequestSize: 1024, + maxResponseSize: 2048, + } + + handler := func(ctx 
context.Context, req testRequest, from peer.ID) (testResponse, error) { + return testResponse{Message: "pong", ID: req.ID}, nil + } + + opts := HandlerOptions{ + Encoder: &mockEncoder{}, + RequestTimeout: 10 * time.Second, + } + + logger := logrus.New() + + h := NewHandler(proto, handler, opts, logger) + require.NotNil(t, h) + + // Verify it implements StreamHandler + var _ StreamHandler = h +} + +func TestHandler_HandleStream(t *testing.T) { + logger := logrus.New() + logger.SetLevel(logrus.DebugLevel) + + tests := []struct { + name string + setupStream func() *mockStream + handler RequestHandler[testRequest, testResponse] + encoder Encoder + compressor Compressor + maxRequestSize uint64 + expectedStatus Status + expectedResp string + }{ + { + name: "successful_request", + setupStream: func() *mockStream { + stream := newMockStream("test-stream", "/test/1.0.0", "remote", "local") + // Prepare request data + reqData := []byte(`{"Message":"ping","ID":1}`) + var buf []byte + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqData))) + buf = append(buf, sizeBuf...) + buf = append(buf, reqData...) 
+ stream.setReadData(buf) + + return stream + }, + handler: func(ctx context.Context, req testRequest, from peer.ID) (testResponse, error) { + return testResponse{Message: "pong", ID: req.ID, Time: time.Now()}, nil + }, + encoder: &mockEncoder{ + encodeFunc: func(msg any) ([]byte, error) { + if req, ok := msg.(testRequest); ok { + return []byte(`{"Message":"` + req.Message + `","ID":1}`), nil + } + if resp, ok := msg.(testResponse); ok { + return []byte(`{"Message":"` + resp.Message + `","ID":1}`), nil + } + + return nil, errors.New("unknown type") + }, + decodeFunc: func(data []byte, msgType any) error { + if req, ok := msgType.(*testRequest); ok { + req.Message = "ping" + req.ID = 1 + + return nil + } + + return errors.New("unknown type") + }, + }, + maxRequestSize: 1024, + expectedStatus: StatusSuccess, + expectedResp: `{"Message":"pong","ID":1}`, + }, + { + name: "request_too_large", + setupStream: func() *mockStream { + stream := newMockStream("test-stream", "/test/1.0.0", "remote", "local") + // Prepare oversized request + var buf []byte + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, 2048) // Exceeds max size + buf = append(buf, sizeBuf...) + stream.setReadData(buf) + + return stream + }, + handler: func(ctx context.Context, req testRequest, from peer.ID) (testResponse, error) { + return testResponse{}, nil + }, + encoder: &mockEncoder{}, + maxRequestSize: 1024, + expectedStatus: StatusInvalidRequest, + }, + { + name: "handler_returns_error", + setupStream: func() *mockStream { + stream := newMockStream("test-stream", "/test/1.0.0", "remote", "local") + reqData := []byte(`{"Message":"ping","ID":1}`) + var buf []byte + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqData))) + buf = append(buf, sizeBuf...) + buf = append(buf, reqData...) 
+ stream.setReadData(buf) + + return stream + }, + handler: func(ctx context.Context, req testRequest, from peer.ID) (testResponse, error) { + return testResponse{}, errors.New("handler error") + }, + encoder: &mockEncoder{ + decodeFunc: func(data []byte, msgType any) error { + if req, ok := msgType.(*testRequest); ok { + req.Message = "ping" + req.ID = 1 + + return nil + } + + return errors.New("unknown type") + }, + }, + maxRequestSize: 1024, + expectedStatus: StatusServerError, + }, + { + name: "decode_error", + setupStream: func() *mockStream { + stream := newMockStream("test-stream", "/test/1.0.0", "remote", "local") + reqData := []byte("invalid json") + var buf []byte + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqData))) + buf = append(buf, sizeBuf...) + buf = append(buf, reqData...) + stream.setReadData(buf) + + return stream + }, + handler: func(ctx context.Context, req testRequest, from peer.ID) (testResponse, error) { + return testResponse{}, nil + }, + encoder: &mockEncoder{ + decodeFunc: func(data []byte, msgType any) error { + return errors.New("decode error") + }, + }, + maxRequestSize: 1024, + expectedStatus: StatusInvalidRequest, + }, + { + name: "with_compression", + setupStream: func() *mockStream { + stream := newMockStream("test-stream", "/test/1.0.0", "remote", "local") + // Prepare compressed request + reqData := []byte("COMPRESSED:ping") + var buf []byte + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqData))) + buf = append(buf, sizeBuf...) + buf = append(buf, reqData...) 
+ stream.setReadData(buf) + + return stream + }, + handler: func(ctx context.Context, req testRequest, from peer.ID) (testResponse, error) { + return testResponse{Message: "pong", ID: req.ID}, nil + }, + encoder: &mockEncoder{ + encodeFunc: func(msg any) ([]byte, error) { + if resp, ok := msg.(testResponse); ok { + return []byte(resp.Message), nil + } + + return []byte("ping"), nil + }, + decodeFunc: func(data []byte, msgType any) error { + if req, ok := msgType.(*testRequest); ok { + req.Message = string(data) + req.ID = 1 + + return nil + } + + return nil + }, + }, + compressor: &mockCompressor{}, + maxRequestSize: 1024, + expectedStatus: StatusSuccess, + expectedResp: "COMPRESSED:pong", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stream := tt.setupStream() + ctx := context.Background() + + proto := testProtocol{ + id: "/test/1.0.0", + maxRequestSize: tt.maxRequestSize, + maxResponseSize: 2048, + } + + opts := HandlerOptions{ + Encoder: tt.encoder, + Compressor: tt.compressor, + RequestTimeout: 5 * time.Second, + } + + h := NewHandler(proto, tt.handler, opts, logger) + + // Handle the stream + h.HandleStream(ctx, stream) + + // Check what was written to the stream + written := stream.getWrittenData() + require.GreaterOrEqual(t, len(written), 1, "Should have written at least status byte") + + // Check status + status := Status(written[0]) + assert.Equal(t, tt.expectedStatus, status) + + if tt.expectedStatus == StatusSuccess && tt.expectedResp != "" { + // Check response data + require.GreaterOrEqual(t, len(written), 5, "Should have status + size + data") + size := binary.BigEndian.Uint32(written[1:5]) + require.Equal(t, len(written)-5, int(size), "Size should match data length") + + respData := written[5:] + assert.Equal(t, tt.expectedResp, string(respData)) + } + }) + } +} + +func TestHandler_ReadRequest(t *testing.T) { + h := &Handler[testRequest, testResponse]{ + log: logrus.New(), + protocol: testProtocol{ + id: 
"/test/1.0.0", + maxRequestSize: 1024, + maxResponseSize: 2048, + }, + } + + tests := []struct { + name string + setupStream func() network.Stream + maxSize uint64 + encoder Encoder + compressor Compressor + expectedReq testRequest + expectedError string + }{ + { + name: "successful_read", + setupStream: func() network.Stream { + stream := newMockStream("test", "/test/1.0.0", "local", "remote") + data := []byte("test request") + var buf []byte + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(data))) + buf = append(buf, sizeBuf...) + buf = append(buf, data...) + stream.setReadData(buf) + + return stream + }, + maxSize: 1024, + encoder: &mockEncoder{ + decodeFunc: func(data []byte, msgType any) error { + if req, ok := msgType.(*testRequest); ok { + req.Message = string(data) + req.ID = 123 + + return nil + } + + return errors.New("unknown type") + }, + }, + expectedReq: testRequest{Message: "test request", ID: 123}, + }, + { + name: "size_exceeds_max", + setupStream: func() network.Stream { + stream := newMockStream("test", "/test/1.0.0", "local", "remote") + var buf []byte + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, 2048) + buf = append(buf, sizeBuf...) + stream.setReadData(buf) + + return stream + }, + maxSize: 1024, + encoder: &mockEncoder{}, + expectedError: "exceeds max", + }, + { + name: "empty_request", + setupStream: func() network.Stream { + stream := newMockStream("test", "/test/1.0.0", "local", "remote") + var buf []byte + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, 0) + buf = append(buf, sizeBuf...) 
+ stream.setReadData(buf) + + return stream + }, + maxSize: 1024, + encoder: &mockEncoder{}, + expectedError: "empty request", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stream := tt.setupStream() + h.encoder = tt.encoder + h.compressor = tt.compressor + h.protocol = testProtocol{ + id: "/test/1.0.0", + maxRequestSize: tt.maxSize, + maxResponseSize: 2048, + } + + req, err := h.readRequest(stream) + + if tt.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedReq, req) + } + }) + } +} + +func TestHandler_WriteResponse(t *testing.T) { + h := &Handler[testRequest, testResponse]{ + log: logrus.New(), + protocol: testProtocol{ + id: "/test/1.0.0", + maxRequestSize: 1024, + maxResponseSize: 2048, + }, + } + + tests := []struct { + name string + response testResponse + encoder Encoder + compressor Compressor + expectedError string + verifyWrite func(t *testing.T, data []byte) + }{ + { + name: "successful_write", + response: testResponse{Message: "test response", ID: 123}, + encoder: &mockEncoder{ + encodeFunc: func(msg any) ([]byte, error) { + if resp, ok := msg.(testResponse); ok { + return []byte(resp.Message), nil + } + + return nil, errors.New("unknown type") + }, + }, + verifyWrite: func(t *testing.T, data []byte) { + t.Helper() + require.GreaterOrEqual(t, len(data), 5) + assert.Equal(t, byte(StatusSuccess), data[0]) + size := binary.BigEndian.Uint32(data[1:5]) + assert.Equal(t, uint32(len("test response")), size) + assert.Equal(t, "test response", string(data[5:])) + }, + }, + { + name: "encode_error", + response: testResponse{Message: "test", ID: 123}, + encoder: &mockEncoder{ + encodeFunc: func(msg any) ([]byte, error) { + return nil, errors.New("encode error") + }, + }, + expectedError: "encode error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stream := newMockStream("test", 
"/test/1.0.0", "local", "remote") + h.encoder = tt.encoder + h.compressor = tt.compressor + + err := h.writeResponse(stream, StatusSuccess, tt.response) + + if tt.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + } else { + require.NoError(t, err) + if tt.verifyWrite != nil { + data := stream.getWrittenData() + tt.verifyWrite(t, data) + } + } + }) + } +} + +func TestHandler_WriteErrorResponse(t *testing.T) { + h := &Handler[testRequest, testResponse]{ + log: logrus.New(), + protocol: testProtocol{ + id: "/test/1.0.0", + maxRequestSize: 1024, + maxResponseSize: 2048, + }, + } + + tests := []struct { + name string + status Status + verifyWrite func(t *testing.T, data []byte) + }{ + { + name: "invalid_request_error", + status: StatusInvalidRequest, + verifyWrite: func(t *testing.T, data []byte) { + t.Helper() + require.Equal(t, 1, len(data)) + assert.Equal(t, byte(StatusInvalidRequest), data[0]) + }, + }, + { + name: "server_error", + status: StatusServerError, + verifyWrite: func(t *testing.T, data []byte) { + t.Helper() + require.Equal(t, 1, len(data)) + assert.Equal(t, byte(StatusServerError), data[0]) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stream := newMockStream("test", "/test/1.0.0", "local", "remote") + // Test the standalone writeErrorResponse which just writes status + err := h.writeResponse(stream, tt.status, testResponse{}) + require.NoError(t, err) + + if tt.verifyWrite != nil { + data := stream.getWrittenData() + tt.verifyWrite(t, data) + } + }) + } +} + +func TestHandler_TimeoutHandling(t *testing.T) { + proto := testProtocol{ + id: "/test/1.0.0", + maxRequestSize: 1024, + maxResponseSize: 2048, + } + + // Handler that takes longer than timeout + slowHandler := func(ctx context.Context, req testRequest, from peer.ID) (testResponse, error) { + select { + case <-time.After(200 * time.Millisecond): + return testResponse{Message: "too late"}, nil + case 
<-ctx.Done(): + return testResponse{}, ctx.Err() + } + } + + opts := HandlerOptions{ + Encoder: &mockEncoder{ + decodeFunc: func(data []byte, msgType any) error { + if req, ok := msgType.(*testRequest); ok { + req.Message = "test" + + return nil + } + + return nil + }, + }, + RequestTimeout: 50 * time.Millisecond, // Very short timeout + } + + logger := logrus.New() + h := NewHandler(proto, slowHandler, opts, logger) + + // Setup stream with valid request + stream := newMockStream("test-stream", "/test/1.0.0", "remote", "local") + reqData := []byte("test request") + var buf []byte + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqData))) + buf = append(buf, sizeBuf...) + buf = append(buf, reqData...) + stream.setReadData(buf) + + ctx := context.Background() + h.HandleStream(ctx, stream) + + // Verify error response was written + written := stream.getWrittenData() + require.GreaterOrEqual(t, len(written), 1) + assert.Equal(t, byte(StatusServerError), written[0]) +} + +func TestHandler_PanicRecovery(t *testing.T) { + proto := testProtocol{ + id: "/test/1.0.0", + maxRequestSize: 1024, + maxResponseSize: 2048, + } + + // Handler that panics + panicHandler := func(ctx context.Context, req testRequest, from peer.ID) (testResponse, error) { + panic("handler panic") + } + + opts := HandlerOptions{ + Encoder: &mockEncoder{ + decodeFunc: func(data []byte, msgType any) error { + if req, ok := msgType.(*testRequest); ok { + req.Message = "test" + + return nil + } + + return nil + }, + }, + RequestTimeout: 5 * time.Second, + } + + logger := logrus.New() + h := NewHandler(proto, panicHandler, opts, logger) + + // Setup stream with valid request + stream := newMockStream("test-stream", "/test/1.0.0", "remote", "local") + reqData := []byte("test request") + var buf []byte + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqData))) + buf = append(buf, sizeBuf...) + buf = append(buf, reqData...) 
+ stream.setReadData(buf) + + ctx := context.Background() + + // Should not panic + assert.NotPanics(t, func() { + h.HandleStream(ctx, stream) + }) + + // Verify error response was written + written := stream.getWrittenData() + require.GreaterOrEqual(t, len(written), 1) + assert.Equal(t, byte(StatusServerError), written[0]) +} diff --git a/pkg/consensus/mimicry/p2p/reqresp/v1/mocks_test.go b/pkg/consensus/mimicry/p2p/reqresp/v1/mocks_test.go new file mode 100644 index 0000000..36bf5a9 --- /dev/null +++ b/pkg/consensus/mimicry/p2p/reqresp/v1/mocks_test.go @@ -0,0 +1,658 @@ +package v1 + +import ( + "context" + "errors" + "fmt" + "io" + "sync" + "time" + + "github.com/libp2p/go-libp2p/core/connmgr" + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/core/protocol" + ma "github.com/multiformats/go-multiaddr" +) + +// mockHost implements a mock libp2p host for testing. 
+type mockHost struct { + mu sync.RWMutex + id peer.ID + streams map[string]*mockStream + handlers map[protocol.ID]network.StreamHandler + newStreamFunc func(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) +} + +func newMockHost(id peer.ID) *mockHost { + return &mockHost{ + id: id, + streams: make(map[string]*mockStream), + handlers: make(map[protocol.ID]network.StreamHandler), + } +} + +func (h *mockHost) ID() peer.ID { + return h.id +} + +func (h *mockHost) Peerstore() peerstore.Peerstore { + return nil +} + +func (h *mockHost) Addrs() []ma.Multiaddr { + return nil +} + +func (h *mockHost) Network() network.Network { + return nil +} + +func (h *mockHost) Mux() protocol.Switch { + return nil +} + +func (h *mockHost) Connect(ctx context.Context, pi peer.AddrInfo) error { + return nil +} + +func (h *mockHost) SetStreamHandler(pid protocol.ID, handler network.StreamHandler) { + h.mu.Lock() + defer h.mu.Unlock() + h.handlers[pid] = handler +} + +func (h *mockHost) SetStreamHandlerMatch(pid protocol.ID, m func(protocol.ID) bool, handler network.StreamHandler) { + h.SetStreamHandler(pid, handler) +} + +func (h *mockHost) RemoveStreamHandler(pid protocol.ID) { + h.mu.Lock() + defer h.mu.Unlock() + delete(h.handlers, pid) +} + +func (h *mockHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) { + h.mu.Lock() + defer h.mu.Unlock() + + if h.newStreamFunc != nil { + return h.newStreamFunc(ctx, p, pids...) + } + + if len(pids) == 0 { + return nil, errors.New("no protocol specified") + } + + streamID := fmt.Sprintf("%s-%s-%s", h.id, p, pids[0]) + stream := newMockStream(streamID, pids[0], h.id, p) + h.streams[streamID] = stream + + return stream, nil +} + +func (h *mockHost) Close() error { + return nil +} + +func (h *mockHost) ConnManager() connmgr.ConnManager { + return nil +} + +func (h *mockHost) EventBus() event.Bus { + return nil +} + +// mockStream implements a mock network stream for testing. 
+type mockStream struct { + mu sync.RWMutex + id string + protocol protocol.ID + localPeer peer.ID + remotePeer peer.ID + readBuffer []byte + writeBuffer []byte + readClosed bool + writeClosed bool + resetErr error + closeErr error + deadline time.Time + readDeadline time.Time + writeDeadline time.Time + stat network.Stats + connectedStream *mockStream + readChan chan []byte // Channel for blocking reads +} + +func newMockStream(id string, proto protocol.ID, local, remote peer.ID) *mockStream { + return &mockStream{ + id: id, + protocol: proto, + localPeer: local, + remotePeer: remote, + readChan: make(chan []byte, 100), + } +} + +func (s *mockStream) Read(p []byte) (int, error) { + // For integration tests, we need to simulate blocking behavior + // Try multiple times with small delays to allow the handler to write + for i := 0; i < 100; i++ { + s.mu.Lock() + + if s.readClosed { + s.mu.Unlock() + + return 0, io.EOF + } + + if s.resetErr != nil { + err := s.resetErr + s.mu.Unlock() + + return 0, err + } + + if len(s.readBuffer) > 0 { + n := copy(p, s.readBuffer) + s.readBuffer = s.readBuffer[n:] + s.mu.Unlock() + + return n, nil + } + + s.mu.Unlock() + + // If no data yet and stream is not closed, wait a bit + time.Sleep(10 * time.Millisecond) + } + + // After timeout, return EOF + return 0, io.EOF +} + +func (s *mockStream) Write(p []byte) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.writeClosed { + return 0, errors.New("write on closed stream") + } + + if s.resetErr != nil { + return 0, s.resetErr + } + + s.writeBuffer = append(s.writeBuffer, p...) + + // If this stream has a connected peer, write to their read buffer + if s.connectedStream != nil { + s.connectedStream.mu.Lock() + s.connectedStream.readBuffer = append(s.connectedStream.readBuffer, p...) 
+ s.connectedStream.mu.Unlock() + } + + return len(p), nil +} + +func (s *mockStream) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closeErr != nil { + return s.closeErr + } + + s.readClosed = true + s.writeClosed = true + + return nil +} + +func (s *mockStream) CloseRead() error { + s.mu.Lock() + defer s.mu.Unlock() + + s.readClosed = true + + return nil +} + +func (s *mockStream) CloseWrite() error { + s.mu.Lock() + defer s.mu.Unlock() + + s.writeClosed = true + + return nil +} + +func (s *mockStream) Reset() error { + s.mu.Lock() + defer s.mu.Unlock() + + s.resetErr = network.ErrReset + s.readClosed = true + s.writeClosed = true + + return nil +} + +func (s *mockStream) ResetWithError(errCode network.StreamErrorCode) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Convert error code to error + s.resetErr = fmt.Errorf("stream reset with error code: %d", errCode) + s.readClosed = true + s.writeClosed = true + + return nil +} + +func (s *mockStream) SetDeadline(t time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.deadline = t + s.readDeadline = t + s.writeDeadline = t + + return nil +} + +func (s *mockStream) SetReadDeadline(t time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.readDeadline = t + + return nil +} + +func (s *mockStream) SetWriteDeadline(t time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.writeDeadline = t + + return nil +} + +func (s *mockStream) ID() string { + return s.id +} + +func (s *mockStream) Protocol() protocol.ID { + return s.protocol +} + +func (s *mockStream) SetProtocol(id protocol.ID) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.protocol = id + + return nil +} + +func (s *mockStream) Stat() network.Stats { + return s.stat +} + +func (s *mockStream) Conn() network.Conn { + return &mockConn{ + localPeer: s.localPeer, + remotePeer: s.remotePeer, + } +} + +func (s *mockStream) Scope() network.StreamScope { + return nil +} + +// Helper methods for testing. 
+func (s *mockStream) setReadData(data []byte) { + s.mu.Lock() + defer s.mu.Unlock() + + s.readBuffer = data +} + +func (s *mockStream) getWrittenData() []byte { + s.mu.RLock() + defer s.mu.RUnlock() + + data := make([]byte, len(s.writeBuffer)) + copy(data, s.writeBuffer) + + return data +} + +// mockEncoder implements a simple encoder for testing. +type mockEncoder struct { + encodeFunc func(msg any) ([]byte, error) + decodeFunc func(data []byte, msgType any) error +} + +func (e *mockEncoder) Encode(msg any) ([]byte, error) { + if e.encodeFunc != nil { + return e.encodeFunc(msg) + } + + // Simple string encoding for testing + if str, ok := msg.(string); ok { + return []byte(str), nil + } + + // Handle string pointer + if strPtr, ok := msg.(*string); ok { + return []byte(*strPtr), nil + } + + return nil, errors.New("unsupported type") +} + +func (e *mockEncoder) Decode(data []byte, msgType any) error { + if e.decodeFunc != nil { + return e.decodeFunc(data, msgType) + } + + // Simple string decoding for testing + if ptr, ok := msgType.(*string); ok { + *ptr = string(data) + + return nil + } + + return errors.New("unsupported type") +} + +// mockCompressor implements a simple compressor for testing. +type mockCompressor struct { + compressFunc func(data []byte) ([]byte, error) + decompressFunc func(data []byte) ([]byte, error) +} + +func (c *mockCompressor) Compress(data []byte) ([]byte, error) { + if c.compressFunc != nil { + return c.compressFunc(data) + } + + // Simple prefix compression for testing + return append([]byte("COMPRESSED:"), data...), nil +} + +func (c *mockCompressor) Decompress(data []byte) ([]byte, error) { + if c.decompressFunc != nil { + return c.decompressFunc(data) + } + + // Simple prefix decompression for testing + prefix := []byte("COMPRESSED:") + if len(data) < len(prefix) { + return nil, errors.New("invalid compressed data") + } + + return data[len(prefix):], nil +} + +// mockStreamHandler implements StreamHandler for testing. 
+type mockStreamHandler struct { + handleFunc func(ctx context.Context, stream network.Stream) +} + +func (h *mockStreamHandler) HandleStream(ctx context.Context, stream network.Stream) { + if h.handleFunc != nil { + h.handleFunc(ctx, stream) + } +} + +// Test protocol implementations. +type testProtocol struct { + id protocol.ID + maxRequestSize uint64 + maxResponseSize uint64 +} + +func (p testProtocol) ID() protocol.ID { + return p.id +} + +func (p testProtocol) MaxRequestSize() uint64 { + return p.maxRequestSize +} + +func (p testProtocol) MaxResponseSize() uint64 { + return p.maxResponseSize +} + +// Chunked test protocol. +type testChunkedProtocol struct { + testProtocol + chunked bool +} + +func (p testChunkedProtocol) IsChunked() bool { + return p.chunked +} + +// Test request/response types. +type testRequest struct { + Message string + ID int +} + +type testResponse struct { + Message string + ID int + Time time.Time +} + +// Helper function to create a connected pair of mock hosts. 
+func createConnectedMockHosts() (*mockHost, *mockHost) { + host1 := newMockHost("peer1") + host2 := newMockHost("peer2") + + // Set up host1 to create streams that connect to host2's handlers + host1.newStreamFunc = func(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) { + if len(pids) == 0 { + return nil, errors.New("no protocol specified") + } + + // Create client stream + clientStreamID := fmt.Sprintf("%s-%s-%s-client", host1.id, p, pids[0]) + clientStream := newMockStream(clientStreamID, pids[0], host1.id, p) + + // Create server stream + serverStreamID := fmt.Sprintf("%s-%s-%s-server", p, host1.id, pids[0]) + serverStream := newMockStream(serverStreamID, pids[0], p, host1.id) + + // Connect the streams bidirectionally + clientStream.connectedStream = serverStream + serverStream.connectedStream = clientStream + + // Find handler in host2 + host2.mu.RLock() + handler, ok := host2.handlers[pids[0]] + host2.mu.RUnlock() + + if ok && handler != nil { + // Create bidirectional stream for the server + bidiStream := &bidirectionalMockStream{ + clientStream: clientStream, + serverStream: serverStream, + } + + // Simulate the remote side handling the stream + // We run this in a goroutine to mimic real network behavior + go func() { + // Call the handler with the stream + handler(bidiStream) + }() + } + + return clientStream, nil + } + + return host1, host2 +} + +// mockConn implements a mock network connection for testing. 
+// mockConn is a minimal network.Conn stub: only the peer-identity accessors
+// return real data; every other method is a zero-value or "not implemented"
+// placeholder, which is sufficient for the req/resp tests in this package.
+type mockConn struct {
+	localPeer  peer.ID
+	remotePeer peer.ID
+}
+
+// Close is a no-op; the mock connection holds no resources.
+func (c *mockConn) Close() error {
+	return nil
+}
+
+func (c *mockConn) CloseWithError(errCode network.ConnErrorCode) error {
+	return nil
+}
+
+func (c *mockConn) IsClosed() bool {
+	return false
+}
+
+// ID derives a stable identifier from the two peer IDs.
+func (c *mockConn) ID() string {
+	return fmt.Sprintf("%s-%s", c.localPeer, c.remotePeer)
+}
+
+func (c *mockConn) NewStream(context.Context) (network.Stream, error) {
+	return nil, errors.New("not implemented")
+}
+
+func (c *mockConn) GetStreams() []network.Stream {
+	return nil
+}
+
+func (c *mockConn) Stat() network.ConnStats {
+	return network.ConnStats{}
+}
+
+func (c *mockConn) Scope() network.ConnScope {
+	return nil
+}
+
+func (c *mockConn) LocalPeer() peer.ID {
+	return c.localPeer
+}
+
+func (c *mockConn) RemotePeer() peer.ID {
+	return c.remotePeer
+}
+
+func (c *mockConn) RemotePublicKey() ic.PubKey {
+	return nil
+}
+
+func (c *mockConn) ConnState() network.ConnectionState {
+	return network.ConnectionState{}
+}
+
+func (c *mockConn) LocalMultiaddr() ma.Multiaddr {
+	return nil
+}
+
+func (c *mockConn) RemoteMultiaddr() ma.Multiaddr {
+	return nil
+}
+
+// bidirectionalMockStream simulates a bidirectional stream. It is the view
+// handed to the SERVER-side handler: Write delivers bytes to the client,
+// Read consumes bytes the client sent; all other methods delegate to the
+// underlying serverStream.
+type bidirectionalMockStream struct {
+	clientStream *mockStream
+	serverStream *mockStream
+}
+
+// Write appends p to the client's read buffer (server -> client direction).
+func (s *bidirectionalMockStream) Write(p []byte) (int, error) {
+	// Server writes to client's read buffer
+	s.clientStream.mu.Lock()
+	s.clientStream.readBuffer = append(s.clientStream.readBuffer, p...)
+	s.clientStream.mu.Unlock()
+
+	return len(p), nil
+}
+
+// Read drains the server's read buffer (client -> server direction).
+// NOTE(review): an empty buffer yields io.EOF immediately rather than
+// blocking for more data — handlers must not expect a streaming read.
+func (s *bidirectionalMockStream) Read(p []byte) (int, error) {
+	// Server reads from its own read buffer (what client wrote to serverStream via connectedStream)
+	s.serverStream.mu.Lock()
+	defer s.serverStream.mu.Unlock()
+
+	if s.serverStream.readClosed {
+		return 0, io.EOF
+	}
+
+	if s.serverStream.resetErr != nil {
+		return 0, s.serverStream.resetErr
+	}
+
+	if len(s.serverStream.readBuffer) == 0 {
+		return 0, io.EOF
+	}
+
+	n := copy(p, s.serverStream.readBuffer)
+	s.serverStream.readBuffer = s.serverStream.readBuffer[n:]
+
+	return n, nil
+}
+
+// The remaining methods simply delegate to the server-side mock stream.
+
+func (s *bidirectionalMockStream) Close() error {
+	return s.serverStream.Close()
+}
+
+func (s *bidirectionalMockStream) CloseRead() error {
+	return s.serverStream.CloseRead()
+}
+
+func (s *bidirectionalMockStream) CloseWrite() error {
+	return s.serverStream.CloseWrite()
+}
+
+func (s *bidirectionalMockStream) Reset() error {
+	return s.serverStream.Reset()
+}
+
+func (s *bidirectionalMockStream) ResetWithError(errCode network.StreamErrorCode) error {
+	return s.serverStream.ResetWithError(errCode)
+}
+
+func (s *bidirectionalMockStream) SetDeadline(t time.Time) error {
+	return s.serverStream.SetDeadline(t)
+}
+
+func (s *bidirectionalMockStream) SetReadDeadline(t time.Time) error {
+	return s.serverStream.SetReadDeadline(t)
+}
+
+func (s *bidirectionalMockStream) SetWriteDeadline(t time.Time) error {
+	return s.serverStream.SetWriteDeadline(t)
+}
+
+func (s *bidirectionalMockStream) ID() string {
+	return s.serverStream.ID()
+}
+
+func (s *bidirectionalMockStream) Protocol() protocol.ID {
+	return s.serverStream.Protocol()
+}
+
+func (s *bidirectionalMockStream) SetProtocol(id protocol.ID) error {
+	return s.serverStream.SetProtocol(id)
+}
+
+func (s *bidirectionalMockStream) Stat() network.Stats {
+	return s.serverStream.Stat()
+}
+
+func (s *bidirectionalMockStream) Conn() network.Conn {
+	return s.serverStream.Conn()
+}
+
+func (s *bidirectionalMockStream) Scope() network.StreamScope {
+	return s.serverStream.Scope()
+}
diff --git a/pkg/consensus/mimicry/p2p/reqresp/v1/reqresp.go b/pkg/consensus/mimicry/p2p/reqresp/v1/reqresp.go
index b7c4a3b..5046229 100644
--- a/pkg/consensus/mimicry/p2p/reqresp/v1/reqresp.go
+++ b/pkg/consensus/mimicry/p2p/reqresp/v1/reqresp.go
@@ -65,7 +65,7 @@ func (r *ReqResp) Stop() error {
 	defer r.mu.Unlock()
 
 	if !r.started {
-		return fmt.Errorf("service not started")
+		return nil
 	}
 
 	// Remove all handlers from the host
diff --git a/pkg/consensus/mimicry/p2p/reqresp/v1/reqresp_test.go b/pkg/consensus/mimicry/p2p/reqresp/v1/reqresp_test.go
new file mode 100644
index 0000000..2d083e1
--- /dev/null
+++ b/pkg/consensus/mimicry/p2p/reqresp/v1/reqresp_test.go
@@ -0,0 +1,459 @@
+package v1
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/libp2p/go-libp2p/core/network"
+	"github.com/libp2p/go-libp2p/core/peer"
+	"github.com/libp2p/go-libp2p/core/protocol"
+	"github.com/sirupsen/logrus"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// TestNew verifies the constructor returns a non-nil Service implementation.
+func TestNew(t *testing.T) {
+	host := newMockHost("test-peer")
+	config := DefaultServiceConfig()
+	logger := logrus.New()
+
+	service := New(host, config, logger)
+	require.NotNil(t, service)
+
+	// Verify it implements Service interface
+	var _ Service = service
+}
+
+// TestService_StartStop covers the start/stop lifecycle: double-start fails,
+// and Stop is idempotent (second Stop returns nil per the change above).
+func TestService_StartStop(t *testing.T) {
+	host := newMockHost("test-peer")
+	config := DefaultServiceConfig()
+	logger := logrus.New()
+
+	service := New(host, config, logger)
+
+	// Start the service
+	ctx := context.Background()
+	err := service.Start(ctx)
+	require.NoError(t, err)
+
+	// Start again should return error
+	err = service.Start(ctx)
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "already started")
+
+	// Stop the service
+	err = service.Stop()
+	require.NoError(t, err)
+
+	// Stop again should be safe
+	err = service.Stop()
+	require.NoError(t, err)
+}
+
+func
TestService_RegisterUnregister(t *testing.T) {
+	host := newMockHost("test-peer")
+	config := DefaultServiceConfig()
+	logger := logrus.New()
+
+	service := New(host, config, logger)
+
+	// Start the service
+	ctx := context.Background()
+	err := service.Start(ctx)
+	require.NoError(t, err)
+	defer func() { _ = service.Stop() }()
+
+	protocolID := protocol.ID("/test/1.0.0")
+	handler := &mockStreamHandler{
+		handleFunc: func(ctx context.Context, stream network.Stream) {
+			// Do nothing
+		},
+	}
+
+	// Register handler
+	err = service.Register(protocolID, handler)
+	require.NoError(t, err)
+
+	// Register again should return error
+	err = service.Register(protocolID, handler)
+	require.Error(t, err)
+	assert.Equal(t, ErrHandlerExists, err)
+
+	// Unregister handler
+	err = service.Unregister(protocolID)
+	require.NoError(t, err)
+
+	// Unregister again should return error
+	err = service.Unregister(protocolID)
+	require.Error(t, err)
+	assert.Equal(t, ErrNoHandler, err)
+
+	// Can register again after unregister
+	err = service.Register(protocolID, handler)
+	require.NoError(t, err)
+}
+
+// TestService_RegisterBeforeStart verifies handlers may be registered before
+// Start is called (they are queued until the service starts).
+func TestService_RegisterBeforeStart(t *testing.T) {
+	host := newMockHost("test-peer")
+	config := DefaultServiceConfig()
+	logger := logrus.New()
+
+	service := New(host, config, logger)
+
+	protocolID := protocol.ID("/test/1.0.0")
+	handler := &mockStreamHandler{}
+
+	// Register should work before start (handlers are queued)
+	err := service.Register(protocolID, handler)
+	require.NoError(t, err)
+
+	// Verify handler is registered
+	err = service.Unregister(protocolID)
+	require.NoError(t, err)
+}
+
+// TestService_SendRequestAfterStop verifies requests are rejected with
+// ErrServiceStopped once the service has been stopped.
+func TestService_SendRequestAfterStop(t *testing.T) {
+	host := newMockHost("test-peer")
+	config := DefaultServiceConfig()
+	logger := logrus.New()
+
+	service := New(host, config, logger)
+
+	// Start and stop the service
+	ctx := context.Background()
+	err := service.Start(ctx)
+	require.NoError(t, err)
+	err = service.Stop()
+	require.NoError(t, err)
+
+	// Send request should fail
+	var req, resp string
+	err = service.SendRequest(ctx, "peer123", "/test/1.0.0", &req, &resp)
+	require.Error(t, err)
+	assert.Equal(t, ErrServiceStopped, err)
+}
+
+// TestService_ConcurrentOperations hammers Register/Unregister from many
+// goroutines to surface data races (run with -race).
+func TestService_ConcurrentOperations(t *testing.T) {
+	host := newMockHost("test-peer")
+	config := DefaultServiceConfig()
+	logger := logrus.New()
+
+	service := New(host, config, logger)
+
+	ctx := context.Background()
+	err := service.Start(ctx)
+	require.NoError(t, err)
+	defer func() { _ = service.Stop() }()
+
+	// Run concurrent operations
+	var wg sync.WaitGroup
+	// Named errCh (not "errors") so the imported errors package is not shadowed.
+	errCh := make(chan error, 100)
+
+	// Concurrent registrations
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func(idx int) {
+			defer wg.Done()
+			protocolID := protocol.ID(fmt.Sprintf("/test/%d/1.0.0", idx))
+			handler := &mockStreamHandler{}
+			if err := service.Register(protocolID, handler); err != nil {
+				errCh <- err
+			}
+		}(i)
+	}
+
+	// Concurrent unregistrations
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func(idx int) {
+			defer wg.Done()
+			// Wait a bit to let registrations happen; an unregister that still
+			// races ahead of its register is tolerated via ErrNoHandler below.
+			time.Sleep(10 * time.Millisecond)
+			protocolID := protocol.ID(fmt.Sprintf("/test/%d/1.0.0", idx))
+			if err := service.Unregister(protocolID); err != nil && !errors.Is(err, ErrNoHandler) {
+				errCh <- err
+			}
+		}(i)
+	}
+
+	// Wait for all operations to complete
+	wg.Wait()
+	close(errCh)
+
+	// Check for errors
+	for err := range errCh {
+		t.Errorf("Concurrent operation error: %v", err)
+	}
+}
+
+// TestRegisterProtocol registers a typed request/response protocol through
+// the generic RegisterProtocol helper.
+func TestRegisterProtocol(t *testing.T) {
+	host := newMockHost("test-peer")
+	config := DefaultServiceConfig()
+	logger := logrus.New()
+
+	service := New(host, config, logger)
+
+	ctx := context.Background()
+	err := service.Start(ctx)
+	require.NoError(t, err)
+	defer func() { _ = service.Stop() }()
+
+	proto := testProtocol{
+		id:              "/test/1.0.0",
+		maxRequestSize:  1024,
+		maxResponseSize: 2048,
+	}
+
+	handler := func(ctx context.Context, req testRequest, from peer.ID) (testResponse, error) {
+		return testResponse{Message: "pong", ID: req.ID}, nil
+	}
+
opts := HandlerOptions{
+		Encoder:        &mockEncoder{},
+		RequestTimeout: 10 * time.Second,
+	}
+
+	// Register the protocol
+	err = RegisterProtocol(service, proto, handler, opts)
+	require.NoError(t, err)
+
+	// Verify handler was registered
+	err = service.Unregister(proto.ID())
+	require.NoError(t, err)
+}
+
+// TestRegisterChunkedProtocol registers a chunked (multi-response) protocol
+// through the generic RegisterChunkedProtocol helper.
+func TestRegisterChunkedProtocol(t *testing.T) {
+	host := newMockHost("test-peer")
+	config := DefaultServiceConfig()
+	logger := logrus.New()
+
+	service := New(host, config, logger)
+
+	ctx := context.Background()
+	err := service.Start(ctx)
+	require.NoError(t, err)
+	defer func() { _ = service.Stop() }()
+
+	proto := testChunkedProtocol{
+		testProtocol: testProtocol{
+			id:              "/test/chunked/1.0.0",
+			maxRequestSize:  1024,
+			maxResponseSize: 2048,
+		},
+		chunked: true,
+	}
+
+	handler := func(ctx context.Context, req testRequest, from peer.ID, writer ChunkedResponseWriter[testResponse]) error {
+		return writer.WriteChunk(testResponse{Message: "chunk", ID: req.ID})
+	}
+
+	opts := HandlerOptions{
+		Encoder:        &mockEncoder{},
+		RequestTimeout: 10 * time.Second,
+	}
+
+	// Register the chunked protocol
+	err = RegisterChunkedProtocol(service, proto, handler, opts)
+	require.NoError(t, err)
+
+	// Verify handler was registered
+	err = service.Unregister(proto.ID())
+	require.NoError(t, err)
+}
+
+// TestNewProtocol checks the NewProtocol constructor's accessors.
+func TestNewProtocol(t *testing.T) {
+	proto := NewProtocol("/test/1.0.0", 1024, 2048)
+
+	assert.Equal(t, protocol.ID("/test/1.0.0"), proto.ID())
+	assert.Equal(t, uint64(1024), proto.MaxRequestSize())
+	assert.Equal(t, uint64(2048), proto.MaxResponseSize())
+}
+
+// TestNewChunkedProtocol checks the NewChunkedProtocol constructor's
+// accessors, including the chunked flag.
+func TestNewChunkedProtocol(t *testing.T) {
+	proto := NewChunkedProtocol("/test/chunked/1.0.0", 1024, 2048)
+
+	assert.Equal(t, protocol.ID("/test/chunked/1.0.0"), proto.ID())
+	assert.Equal(t, uint64(1024), proto.MaxRequestSize())
+	assert.Equal(t, uint64(2048), proto.MaxResponseSize())
+	assert.True(t, proto.IsChunked())
+}
+
+// TestService_IntegrationScenario runs a full request/response round-trip
+// between two in-process services over the connected mock hosts.
+func TestService_IntegrationScenario(t *testing.T) {
+	// Create two mock hosts that can communicate
+	host1, host2 := createConnectedMockHosts()
+
+	config := DefaultServiceConfig()
+	logger := logrus.New()
+
+	// Create services for both hosts
+	service1 := New(host1, config, logger)
+	service2 := New(host2, config, logger)
+
+	ctx := context.Background()
+
+	// Start both services
+	err := service1.Start(ctx)
+	require.NoError(t, err)
+	defer func() { _ = service1.Stop() }()
+
+	err = service2.Start(ctx)
+	require.NoError(t, err)
+	defer func() { _ = service2.Stop() }()
+
+	// Register a handler on service2
+	proto := NewProtocol("/echo/1.0.0", 1024, 1024)
+	echoHandler := func(ctx context.Context, req string, from peer.ID) (string, error) {
+		return "Echo: " + req, nil
+	}
+
+	opts := HandlerOptions{
+		Encoder: &mockEncoder{
+			encodeFunc: func(msg any) ([]byte, error) {
+				if str, ok := msg.(string); ok {
+					return []byte(str), nil
+				}
+				if strPtr, ok := msg.(*string); ok {
+					return []byte(*strPtr), nil
+				}
+
+				return nil, errors.New("unsupported type")
+			},
+			decodeFunc: func(data []byte, msgType any) error {
+				if ptr, ok := msgType.(*string); ok {
+					*ptr = string(data)
+
+					return nil
+				}
+
+				return errors.New("unsupported type")
+			},
+		},
+		RequestTimeout: 5 * time.Second,
+	}
+
+	err = RegisterProtocol(service2, proto, echoHandler, opts)
+	require.NoError(t, err)
+
+	// Send request from service1 to service2
+	req := "Hello, World!"
+	var resp string
+
+	reqOpts := RequestOptions{
+		Encoder: opts.Encoder,
+		Timeout: 5 * time.Second,
+	}
+
+	err = service1.SendRequestWithOptions(ctx, host2.ID(), proto.ID(), &req, &resp, reqOpts)
+	require.NoError(t, err)
+	assert.Equal(t, "Echo: Hello, World!", resp)
+}
+
+// TestService_ChunkedIntegrationScenario runs a full chunked (multi-response)
+// round-trip: service2 streams N chunks back for a single request.
+func TestService_ChunkedIntegrationScenario(t *testing.T) {
+	// Create two mock hosts that can communicate
+	host1, host2 := createConnectedMockHosts()
+
+	config := DefaultServiceConfig()
+	logger := logrus.New()
+
+	// Create services for both hosts
+	service1 := New(host1, config, logger)
+	service2 := New(host2, config, logger)
+
+	ctx := context.Background()
+
+	// Start both services
+	err := service1.Start(ctx)
+	require.NoError(t, err)
+	defer func() { _ = service1.Stop() }()
+
+	err = service2.Start(ctx)
+	require.NoError(t, err)
+	defer func() { _ = service2.Stop() }()
+
+	// Register a chunked handler on service2
+	proto := NewChunkedProtocol("/blocks/1.0.0", 1024, 1024)
+	blocksHandler := func(ctx context.Context, req int, from peer.ID, writer ChunkedResponseWriter[string]) error {
+		// Send multiple chunks
+		for i := 0; i < req; i++ {
+			if writeErr := writer.WriteChunk(fmt.Sprintf("Block %d", i)); writeErr != nil {
+				return writeErr
+			}
+		}
+
+		return nil
+	}
+
+	// Encoder handles both the int request and the string chunks.
+	opts := HandlerOptions{
+		Encoder: &mockEncoder{
+			encodeFunc: func(msg any) ([]byte, error) {
+				if n, ok := msg.(int); ok {
+					return []byte(fmt.Sprintf("%d", n)), nil
+				}
+				if nPtr, ok := msg.(*int); ok {
+					return []byte(fmt.Sprintf("%d", *nPtr)), nil
+				}
+				if str, ok := msg.(string); ok {
+					return []byte(str), nil
+				}
+				if strPtr, ok := msg.(*string); ok {
+					return []byte(*strPtr), nil
+				}
+
+				return nil, errors.New("unsupported type")
+			},
+			decodeFunc: func(data []byte, msgType any) error {
+				if ptr, ok := msgType.(*int); ok {
+					_, scanErr := fmt.Sscanf(string(data), "%d", ptr)
+
+					return scanErr
+				}
+				if ptr, ok := msgType.(*string); ok {
+					*ptr = string(data)
+
+					return nil
+				}
+
+				return errors.New("unsupported type")
+			},
+		},
+		RequestTimeout: 5 * time.Second,
+	}
+
+	err = RegisterChunkedProtocol(service2, proto, blocksHandler, opts)
+	require.NoError(t, err)
+
+	// Send chunked request from service1 to service2
+	req := 3 // Request 3 blocks
+	receivedChunks := []string{}
+
+	// NOTE(review): decode failures are silently skipped here (best effort);
+	// the count assertion below would then fail, which is the intended signal.
+	chunkHandler := func(chunk any) error {
+		if data, ok := chunk.([]byte); ok {
+			var str string
+			if decodeErr := opts.Encoder.Decode(data, &str); decodeErr == nil {
+				receivedChunks = append(receivedChunks, str)
+			}
+		}
+
+		return nil
+	}
+
+	chunkedClient := NewChunkedClient(host1, config.ClientConfig, logger)
+
+	reqOpts := RequestOptions{
+		Encoder: opts.Encoder,
+		Timeout: 5 * time.Second,
+	}
+
+	sendErr := chunkedClient.SendChunkedRequestWithOptions(ctx, host2.ID(), proto.ID(), &req, chunkHandler, reqOpts)
+	require.NoError(t, sendErr)
+
+	// Verify we received the expected chunks
+	assert.Equal(t, 3, len(receivedChunks))
+	assert.Equal(t, "Block 0", receivedChunks[0])
+	assert.Equal(t, "Block 1", receivedChunks[1])
+	assert.Equal(t, "Block 2", receivedChunks[2])
+}
diff --git a/pkg/consensus/mimicry/p2p/reqresp/v1/types_test.go b/pkg/consensus/mimicry/p2p/reqresp/v1/types_test.go
new file mode 100644
index 0000000..cf01c38
--- /dev/null
+++ b/pkg/consensus/mimicry/p2p/reqresp/v1/types_test.go
@@ -0,0 +1,267 @@
+package v1
+
+import (
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// TestStatus_String table-tests the human-readable form of each Status,
+// including the fallback for unknown values.
+func TestStatus_String(t *testing.T) {
+	tests := []struct {
+		name     string
+		status   Status
+		expected string
+	}{
+		{
+			name:     "success",
+			status:   StatusSuccess,
+			expected: "success",
+		},
+		{
+			name:     "invalid_request",
+			status:   StatusInvalidRequest,
+			expected: "invalid_request",
+		},
+		{
+			name:     "server_error",
+			status:   StatusServerError,
+			expected: "server_error",
+		},
+		{
+			name:     "resource_unavailable",
+			status:   StatusResourceUnavailable,
+			expected: "resource_unavailable",
+		},
+		{
+			name:   "rate_limited",
+			status: StatusRateLimited,
expected: "rate_limited",
+		},
+		{
+			name:     "unknown_status",
+			status:   Status(99),
+			expected: "unknown",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := tt.status.String()
+			assert.Equal(t, tt.expected, result)
+		})
+	}
+}
+
+// TestStatus_IsError verifies that every status except StatusSuccess —
+// including unknown values — is classified as an error.
+func TestStatus_IsError(t *testing.T) {
+	tests := []struct {
+		name     string
+		status   Status
+		expected bool
+	}{
+		{
+			name:     "success_is_not_error",
+			status:   StatusSuccess,
+			expected: false,
+		},
+		{
+			name:     "invalid_request_is_error",
+			status:   StatusInvalidRequest,
+			expected: true,
+		},
+		{
+			name:     "server_error_is_error",
+			status:   StatusServerError,
+			expected: true,
+		},
+		{
+			name:     "resource_unavailable_is_error",
+			status:   StatusResourceUnavailable,
+			expected: true,
+		},
+		{
+			name:     "rate_limited_is_error",
+			status:   StatusRateLimited,
+			expected: true,
+		},
+		{
+			name:     "unknown_status_is_error",
+			status:   Status(99),
+			expected: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := tt.status.IsError()
+			assert.Equal(t, tt.expected, result)
+		})
+	}
+}
+
+// TestDefaultServiceConfig pins the documented defaults so accidental
+// changes to DefaultServiceConfig are caught.
+func TestDefaultServiceConfig(t *testing.T) {
+	config := DefaultServiceConfig()
+
+	// Check HandlerOptions
+	assert.Equal(t, 30*time.Second, config.HandlerOptions.RequestTimeout)
+	assert.False(t, config.HandlerOptions.EnableMetrics)
+	assert.Nil(t, config.HandlerOptions.Encoder)
+	assert.Nil(t, config.HandlerOptions.Compressor)
+
+	// Check ClientConfig
+	assert.Equal(t, 30*time.Second, config.ClientConfig.DefaultTimeout)
+	assert.Equal(t, 3, config.ClientConfig.MaxRetries)
+	assert.Equal(t, 1*time.Second, config.ClientConfig.RetryDelay)
+	assert.False(t, config.ClientConfig.EnableMetrics)
+}
+
+// TestProtocolConfig is a field round-trip check for the ProtocolConfig struct.
+func TestProtocolConfig(t *testing.T) {
+	config := ProtocolConfig{
+		ID:              "/test/1.0.0",
+		Version:         "1.0.0",
+		MaxRequestSize:  1024,
+		MaxResponseSize: 2048,
+		Timeout:         5 * time.Second,
+	}
+
+	assert.Equal(t, "/test/1.0.0", string(config.ID))
+	assert.Equal(t, "1.0.0", config.Version)
+	assert.Equal(t, uint64(1024), config.MaxRequestSize)
+	assert.Equal(t, uint64(2048), config.MaxResponseSize)
+	assert.Equal(t, 5*time.Second, config.Timeout)
+}
+
+// TestRequestMetadata is a field round-trip check for RequestMetadata.
+func TestRequestMetadata(t *testing.T) {
+	now := time.Now()
+	meta := RequestMetadata{
+		Protocol:    "/test/1.0.0",
+		PeerID:      "peer123",
+		RequestedAt: now,
+		Size:        256,
+	}
+
+	assert.Equal(t, "/test/1.0.0", string(meta.Protocol))
+	assert.Equal(t, "peer123", meta.PeerID)
+	assert.Equal(t, now, meta.RequestedAt)
+	assert.Equal(t, 256, meta.Size)
+}
+
+// TestResponseMetadata is a field round-trip check for ResponseMetadata.
+func TestResponseMetadata(t *testing.T) {
+	now := time.Now()
+	meta := ResponseMetadata{
+		Protocol:    "/test/1.0.0",
+		PeerID:      "peer123",
+		Status:      StatusSuccess,
+		RespondedAt: now,
+		Size:        512,
+		Duration:    100 * time.Millisecond,
+	}
+
+	assert.Equal(t, "/test/1.0.0", string(meta.Protocol))
+	assert.Equal(t, "peer123", meta.PeerID)
+	assert.Equal(t, StatusSuccess, meta.Status)
+	assert.Equal(t, now, meta.RespondedAt)
+	assert.Equal(t, 512, meta.Size)
+	assert.Equal(t, 100*time.Millisecond, meta.Duration)
+}
+
+// TestHandlerOptions is a field round-trip check for HandlerOptions.
+func TestHandlerOptions(t *testing.T) {
+	encoder := &mockEncoder{}
+	compressor := &mockCompressor{}
+
+	opts := HandlerOptions{
+		Encoder:        encoder,
+		Compressor:     compressor,
+		RequestTimeout: 10 * time.Second,
+		EnableMetrics:  true,
+	}
+
+	assert.Equal(t, encoder, opts.Encoder)
+	assert.Equal(t, compressor, opts.Compressor)
+	assert.Equal(t, 10*time.Second, opts.RequestTimeout)
+	assert.True(t, opts.EnableMetrics)
+}
+
+// TestClientConfig is a field round-trip check for ClientConfig.
+func TestClientConfig(t *testing.T) {
+	config := ClientConfig{
+		DefaultTimeout: 20 * time.Second,
+		MaxRetries:     5,
+		RetryDelay:     2 * time.Second,
+		EnableMetrics:  true,
+	}
+
+	assert.Equal(t, 20*time.Second, config.DefaultTimeout)
+	assert.Equal(t, 5, config.MaxRetries)
+	assert.Equal(t, 2*time.Second, config.RetryDelay)
+	assert.True(t, config.EnableMetrics)
+}
+
+// TestRequestOptions is a field round-trip check for RequestOptions.
+func TestRequestOptions(t *testing.T) {
+	encoder := &mockEncoder{}
+	compressor := &mockCompressor{}
+
+	opts := RequestOptions{
+		Encoder:    encoder,
+		Compressor: compressor,
+		Timeout:    15 * time.Second,
+	}
+
+	assert.Equal(t, encoder, opts.Encoder)
+	assert.Equal(t, compressor, opts.Compressor)
+	assert.Equal(t, 15*time.Second, opts.Timeout)
+}
+
+// TestServiceConfig is a field round-trip check for the composite ServiceConfig.
+func TestServiceConfig(t *testing.T) {
+	encoder := &mockEncoder{}
+	compressor := &mockCompressor{}
+
+	config := ServiceConfig{
+		HandlerOptions: HandlerOptions{
+			Encoder:        encoder,
+			Compressor:     compressor,
+			RequestTimeout: 30 * time.Second,
+			EnableMetrics:  true,
+		},
+		ClientConfig: ClientConfig{
+			DefaultTimeout: 30 * time.Second,
+			MaxRetries:     3,
+			RetryDelay:     1 * time.Second,
+			EnableMetrics:  true,
+		},
+	}
+
+	// Verify HandlerOptions
+	assert.Equal(t, encoder, config.HandlerOptions.Encoder)
+	assert.Equal(t, compressor, config.HandlerOptions.Compressor)
+	assert.Equal(t, 30*time.Second, config.HandlerOptions.RequestTimeout)
+	assert.True(t, config.HandlerOptions.EnableMetrics)
+
+	// Verify ClientConfig
+	assert.Equal(t, 30*time.Second, config.ClientConfig.DefaultTimeout)
+	assert.Equal(t, 3, config.ClientConfig.MaxRetries)
+	assert.Equal(t, 1*time.Second, config.ClientConfig.RetryDelay)
+	assert.True(t, config.ClientConfig.EnableMetrics)
+}
+
+// TestErrorConstants verifies the package's sentinel errors exist and that
+// their messages contain the expected key phrases.
+func TestErrorConstants(t *testing.T) {
+	// Test that error constants are not nil
+	require.NotNil(t, ErrInvalidRequest)
+	require.NotNil(t, ErrInvalidResponse)
+	require.NotNil(t, ErrStreamReset)
+	require.NotNil(t, ErrTimeout)
+	require.NotNil(t, ErrNoHandler)
+	require.NotNil(t, ErrHandlerExists)
+	require.NotNil(t, ErrServiceStopped)
+	require.NotNil(t, ErrMaxSizeExceeded)
+
+	// Test error messages
+	assert.Contains(t, ErrInvalidRequest.Error(), "invalid request")
+	assert.Contains(t, ErrInvalidResponse.Error(), "invalid response")
+	assert.Contains(t, ErrStreamReset.Error(), "stream reset")
+	assert.Contains(t, ErrTimeout.Error(), "timed out")
+	assert.Contains(t, ErrNoHandler.Error(), "no handler")
+	assert.Contains(t, ErrHandlerExists.Error(), "handler already registered")
+	assert.Contains(t, ErrServiceStopped.Error(), "service stopped")
+	assert.Contains(t, ErrMaxSizeExceeded.Error(), "max size exceeded")
+}