Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions internal/ipc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func (c *Client) Close() error {
func (c *Client) readLoop() {
defer c.wg.Done()
defer c.conn.Close()
defer c.drainReqs()

scanner := bufio.NewScanner(c.conn)
scanner.Buffer(make([]byte, maxFrameSize), maxFrameSize)
Expand All @@ -80,7 +81,7 @@ func (c *Client) readLoop() {
}

env := &ipcv1.Envelope{}
if err := protojson.Unmarshal(line, env); err != nil {
if err := (protojson.UnmarshalOptions{DiscardUnknown: true}).Unmarshal(line, env); err != nil {
if c.closed.Load() {
return
}
Expand All @@ -106,6 +107,17 @@ func (c *Client) readLoop() {
}
}

// drainReqs closes all outstanding request channels so Call() callers unblock
// when the connection drops unexpectedly.
func (c *Client) drainReqs() {
	c.mu.Lock()
	defer c.mu.Unlock()
	for reqID, respCh := range c.reqs {
		// Drop the map entry first, then close the channel; the waiting
		// Call() observes the close and reports the connection as gone.
		delete(c.reqs, reqID)
		close(respCh)
	}
}

// Call sends a request and waits for the response.
func (c *Client) Call(ctx context.Context, req *ipcv1.Request) (*ipcv1.Response, error) {
if c.closed.Load() {
Expand Down Expand Up @@ -154,7 +166,10 @@ func (c *Client) Call(ctx context.Context, req *ipcv1.Request) (*ipcv1.Response,
delete(c.reqs, id)
c.mu.Unlock()
return nil, ctx.Err()
case resp := <-ch:
case resp, ok := <-ch:
if !ok {
return nil, errors.New("ipc: connection closed")
}
if errResp := resp.GetError(); errResp != nil {
return nil, fmt.Errorf("remote error: %s (code %d)", errResp.GetMessage(), errResp.GetCode())
}
Expand Down
202 changes: 202 additions & 0 deletions internal/ipc/extra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"errors"
"fmt"
"net"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -216,6 +217,207 @@ func TestServer_ContextCancelBeforeCall(t *testing.T) {
}
}

// dialUnixWithRetry repeatedly attempts to dial the Unix socket until it
// succeeds or the context is done. On failure it returns the last dial error
// (or the context error if no dial ever failed), wrapped with the helper name.
func dialUnixWithRetry(ctx context.Context, sockPath string) (net.Conn, error) {
	var lastErr error

	for {
		if err := ctx.Err(); err != nil {
			if lastErr == nil {
				lastErr = err
			}
			return nil, fmt.Errorf("dialUnixWithRetry: %w", lastErr)
		}

		conn, err := dialUnix(sockPath)
		if err == nil {
			return conn, nil
		}
		lastErr = err

		// Back off briefly between attempts, but wake immediately on
		// cancellation instead of always sleeping the full interval.
		select {
		case <-ctx.Done():
			// Loop around; the top-of-loop check wraps and returns lastErr.
		case <-time.After(10 * time.Millisecond):
		}
	}
}

func TestBroadcastEvent_Delivery(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

sockPath := t.TempDir() + "/broadcast.sock"
srv := NewServer(sockPath, func(_ context.Context, _ *ipcv1.Request) (*ipcv1.Response, error) {
return &ipcv1.Response{Kind: &ipcv1.Response_Status{Status: &ipcv1.StatusResponse{}}}, nil
})
if err := srv.Start(ctx); err != nil {
t.Fatalf("srv.Start: %v", err)
}
defer srv.Stop()

// Connect a raw net.Conn so we can read frames directly.
rawConn, err := dialUnixWithRetry(ctx, sockPath)
if err != nil {
t.Fatalf("dial: %v", err)
}
defer rawConn.Close()

Comment thread
darkliquid marked this conversation as resolved.
// Broadcast an event.
srv.BroadcastEvent(&ipcv1.Event{
Kind: &ipcv1.Event_IndexUpdated{IndexUpdated: &ipcv1.IndexUpdatedEvent{}},
})

// Read the event frame.
scanner := bufio.NewScanner(rawConn)
scanner.Buffer(make([]byte, maxFrameSize), maxFrameSize)

done := make(chan struct{})
var gotErr error
var gotEnv *ipcv1.Envelope
go func() {
defer close(done)
if !scanner.Scan() {
if err := scanner.Err(); err != nil {
gotErr = err
} else {
gotErr = errors.New("scanner stopped before reading broadcast event frame")
}
return
}
env := &ipcv1.Envelope{}
if err := (protojson.UnmarshalOptions{DiscardUnknown: true}).Unmarshal(scanner.Bytes(), env); err != nil {
gotErr = err
return
}
gotEnv = env
}()

select {
case <-done:
case <-ctx.Done():
t.Fatal("timed out waiting for broadcast event")
}

if gotErr != nil {
t.Fatalf("unmarshal broadcast event: %v", gotErr)
}
if gotEnv == nil {
t.Fatal("expected broadcast event envelope, got nil")
}
if gotEnv.GetEvent() == nil {
t.Fatalf("expected event payload, got %T", gotEnv.GetPayload())
}
}

// TestBroadcastEvent_NoInterleave checks that a broadcast racing a blocked
// response write does not corrupt the stream: the in-flight Call must still
// complete without error.
func TestBroadcastEvent_NoInterleave(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// The handler parks until released, so a broadcast can be raced
	// against the pending response write.
	handlerEntered := make(chan struct{})
	release := make(chan struct{})

	sockPath := t.TempDir() + "/nointerleave.sock"
	srv := NewServer(sockPath, func(_ context.Context, _ *ipcv1.Request) (*ipcv1.Response, error) {
		close(handlerEntered)
		<-release
		return &ipcv1.Response{Kind: &ipcv1.Response_Status{Status: &ipcv1.StatusResponse{UptimeSeconds: 7}}}, nil
	})
	if err := srv.Start(ctx); err != nil {
		t.Fatalf("srv.Start: %v", err)
	}
	defer srv.Stop()
	time.Sleep(20 * time.Millisecond)

	client, err := NewClient(ctx, sockPath)
	if err != nil {
		t.Fatalf("NewClient: %v", err)
	}
	defer client.Close()

	// Put one call in flight; the server holds it until release closes.
	callErrCh := make(chan error, 1)
	go func() {
		_, callErr := client.Call(ctx, &ipcv1.Request{Kind: &ipcv1.Request_Status{Status: &ipcv1.StatusRequest{}}})
		callErrCh <- callErr
	}()

	// Block until the handler is definitely inside its critical section.
	select {
	case <-handlerEntered:
	case <-ctx.Done():
		t.Fatal("timed out waiting for handler to start")
	}

	// Fire the broadcast while the response is still pending.
	srv.BroadcastEvent(&ipcv1.Event{Kind: &ipcv1.Event_IndexUpdated{IndexUpdated: &ipcv1.IndexUpdatedEvent{}}})

	// Let the handler return its response.
	close(release)

	// The call must succeed despite the concurrent broadcast.
	select {
	case callErr := <-callErrCh:
		if callErr != nil {
			t.Errorf("Call after concurrent broadcast: %v", callErr)
		}
	case <-ctx.Done():
		t.Fatal("timed out waiting for Call to complete")
	}
}

// TestClient_ConnectionDrop ensures a pending Call unblocks with an error
// when the underlying connection is closed out from under the client.
func TestClient_ConnectionDrop(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// The handler blocks indefinitely so the call is guaranteed to be
	// pending at the moment the connection is dropped.
	block := make(chan struct{})
	started := make(chan struct{})
	sockPath := startTestServer(t, func(_ context.Context, _ *ipcv1.Request) (*ipcv1.Response, error) {
		close(started) // the server has the request and is about to park
		<-block
		return &ipcv1.Response{Kind: &ipcv1.Response_Status{Status: &ipcv1.StatusResponse{}}}, nil
	})

	cli, err := NewClient(ctx, sockPath)
	if err != nil {
		t.Fatalf("NewClient: %v", err)
	}

	// Put one call in flight.
	resultCh := make(chan error, 1)
	go func() {
		_, callErr := cli.Call(ctx, &ipcv1.Request{Kind: &ipcv1.Request_Status{Status: &ipcv1.StatusRequest{}}})
		resultCh <- callErr
	}()

	// Block until the server is definitely holding our request.
	select {
	case <-started:
	case <-ctx.Done():
		t.Fatalf("timed out waiting for call to start: %v", ctx.Err())
	}

	// Simulate an abrupt drop by closing the transport underneath the client.
	cli.conn.Close()

	// The pending call must surface an error rather than hang.
	select {
	case callErr := <-resultCh:
		if callErr == nil {
			t.Error("expected error when connection drops, got nil")
		}
	case <-ctx.Done():
		t.Fatal("timed out: Call did not unblock after connection drop")
	}
}

// dialUnix is a test helper that opens a raw [net.Conn] to a Unix socket.
func dialUnix(path string) (net.Conn, error) {
	conn, err := net.Dial("unix", path)
	if err != nil {
		return nil, err
	}
	return conn, nil
}

func TestServer_AllRequestTypes(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
Expand Down
8 changes: 4 additions & 4 deletions internal/ipc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ func (s *Server) closeActiveConnections() {
s.mu.Lock()
defer s.mu.Unlock()

for cw := range s.conns {
if err := cw.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
for conn := range s.conns {
if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
slog.Debug("ipc server close connection", "err", err)
}
}
Expand All @@ -118,7 +118,7 @@ func (s *Server) untrackConn(conn net.Conn) {
delete(s.conns, conn)
}

// BroadcastEvent pushes an event to all connected clients asynchronously.
// BroadcastEvent pushes an event to all connected clients.
func (s *Server) BroadcastEvent(event *ipcv1.Event) {
s.mu.Lock()
writers := make([]*connWriter, 0, len(s.conns))
Expand Down Expand Up @@ -177,7 +177,7 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, cw *connWriter)
}

env := &ipcv1.Envelope{}
if err := protojson.Unmarshal(line, env); err != nil {
if err := (protojson.UnmarshalOptions{DiscardUnknown: true}).Unmarshal(line, env); err != nil {
if s.closing.Load() {
return
}
Expand Down
Loading