Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions audit/multi_auditor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"log/slog"
"os"

"github.com/google/uuid"
)

// MultiAuditor wraps multiple auditors and sends audit events to all of them.
Expand All @@ -28,7 +30,7 @@ func (m *MultiAuditor) AuditRequest(req Request) {
// provided configuration. It always includes a LogAuditor for stderr logging,
// and conditionally adds a SocketAuditor if audit logs are enabled and the
// workspace agent's log proxy socket exists.
func SetupAuditor(ctx context.Context, logger *slog.Logger, disableAuditLogs bool, logProxySocketPath string) (Auditor, error) {
func SetupAuditor(ctx context.Context, logger *slog.Logger, disableAuditLogs bool, logProxySocketPath string, sessionID uuid.UUID) (Auditor, error) {
stderrAuditor := NewLogAuditor(logger)
auditors := []Auditor{stderrAuditor}

Expand All @@ -48,7 +50,8 @@ func SetupAuditor(ctx context.Context, logger *slog.Logger, disableAuditLogs boo
}
agentWillProxy := !os.IsNotExist(err)
if agentWillProxy {
socketAuditor := NewSocketAuditor(logger, logProxySocketPath)
seq := &SequenceCounter{}
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't currently included in the log auditor. Only the socket auditor. Does that matter? I don't think so, but I'm open to the idea.

socketAuditor := NewSocketAuditor(logger, logProxySocketPath, sessionID, seq)
go socketAuditor.Loop(ctx)
auditors = append(auditors, socketAuditor)
} else {
Expand Down
10 changes: 6 additions & 4 deletions audit/multi_auditor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"os"
"path/filepath"
"testing"

"github.com/google/uuid"
)

type mockAuditor struct {
Expand All @@ -25,7 +27,7 @@ func TestSetupAuditor_DisabledAuditLogs(t *testing.T) {
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
ctx := context.Background()

auditor, err := SetupAuditor(ctx, logger, true, "")
auditor, err := SetupAuditor(ctx, logger, true, "", uuid.New())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
Expand All @@ -50,7 +52,7 @@ func TestSetupAuditor_EmptySocketPath(t *testing.T) {
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
ctx := context.Background()

_, err := SetupAuditor(ctx, logger, false, "")
_, err := SetupAuditor(ctx, logger, false, "", uuid.New())
if err == nil {
t.Fatal("expected error for empty socket path, got nil")
}
Expand All @@ -62,7 +64,7 @@ func TestSetupAuditor_SocketDoesNotExist(t *testing.T) {
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
ctx := context.Background()

auditor, err := SetupAuditor(ctx, logger, false, "/nonexistent/socket/path")
auditor, err := SetupAuditor(ctx, logger, false, "/nonexistent/socket/path", uuid.New())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
Expand Down Expand Up @@ -100,7 +102,7 @@ func TestSetupAuditor_SocketExists(t *testing.T) {
t.Fatalf("failed to close temp file: %v", err)
}

auditor, err := SetupAuditor(ctx, logger, false, socketPath)
auditor, err := SetupAuditor(ctx, logger, false, socketPath, uuid.New())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
Expand Down
20 changes: 20 additions & 0 deletions audit/sequence.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package audit

import "sync/atomic"

// SequenceCounter is a monotonically increasing counter that assigns a
// unique sequence number to every audit event within a single boundary
// session. The counter starts at 0 and is safe for concurrent use by
// both the socket auditor and the proxy.
type SequenceCounter struct {
next atomic.Int32
}

// Next returns the next sequence number. The first call returns 0,
// subsequent calls return 1, 2, 3, etc. It is safe for concurrent
// use.
func (c *SequenceCounter) Next() int32 {
// Add returns the new value after incrementing, so subtract 1
// to produce a zero-based sequence.
return c.next.Add(1) - 1
}
25 changes: 25 additions & 0 deletions audit/sequence_test.go
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: This is basically testing atomic.Int32. The only real valid test I could see is testing that it starts at zero and increments.

Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package audit

import (
"testing"
)

func TestSequenceCounter_StartsAtZero(t *testing.T) {
t.Parallel()

var c SequenceCounter
if got := c.Next(); got != 0 {
t.Fatalf("first call: got %d, want 0", got)
}
}

func TestSequenceCounter_Increments(t *testing.T) {
t.Parallel()

var c SequenceCounter
for i := range int32(100) {
if got := c.Next(); got != i {
t.Fatalf("call %d: got %d, want %d", i, got, i)
}
}
}
23 changes: 16 additions & 7 deletions audit/socket_auditor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"sync/atomic"
"time"

"github.com/google/uuid"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/coder/coder/v2/agent/boundarylogproxy/codec"
Expand All @@ -34,6 +35,8 @@ type SocketAuditor struct {
batchSize int
batchTimerDuration time.Duration
socketPath string
sessionID uuid.UUID
seq *SequenceCounter

droppedChannelFull atomic.Int64
droppedBatchFull atomic.Int64
Expand All @@ -45,7 +48,7 @@ type SocketAuditor struct {
// NewSocketAuditor creates a new SocketAuditor that sends logs to the agent's
// boundary log proxy socket after SocketAuditor.Loop is called. The socket path
// is read from EnvAuditSocketPath, falling back to defaultAuditSocketPath.
func NewSocketAuditor(logger *slog.Logger, socketPath string) *SocketAuditor {
func NewSocketAuditor(logger *slog.Logger, socketPath string, sessionID uuid.UUID, seq *SequenceCounter) *SocketAuditor {
// This channel buffer size intends to allow enough buffering for bursty
// AI agent network requests while a batch is being sent to the workspace
// agent.
Expand All @@ -60,6 +63,8 @@ func NewSocketAuditor(logger *slog.Logger, socketPath string) *SocketAuditor {
batchSize: defaultBatchSize,
batchTimerDuration: defaultBatchTimerDuration,
socketPath: socketPath,
sessionID: sessionID,
seq: seq,
}
}

Expand All @@ -77,9 +82,10 @@ func (s *SocketAuditor) AuditRequest(req Request) {
}

log := &agentproto.BoundaryLog{
Allowed: req.Allowed,
Time: timestamppb.Now(),
Resource: &agentproto.BoundaryLog_HttpRequest_{HttpRequest: httpReq},
Allowed: req.Allowed,
Time: timestamppb.Now(),
SequenceNumber: s.seq.Next(),
Resource: &agentproto.BoundaryLog_HttpRequest_{HttpRequest: httpReq},
}

select {
Expand All @@ -100,14 +106,17 @@ type flushErr struct {
func (e *flushErr) Error() string { return e.err.Error() }

// flush sends the current batch of logs to the given connection.
func flush(conn net.Conn, logs []*agentproto.BoundaryLog) *flushErr {
func flush(conn net.Conn, sessionID uuid.UUID, logs []*agentproto.BoundaryLog) *flushErr {
if len(logs) == 0 {
return nil
}

msg := &codec.BoundaryMessage{
Msg: &codec.BoundaryMessage_Logs{
Logs: &agentproto.ReportBoundaryLogsRequest{Logs: logs},
Logs: &agentproto.ReportBoundaryLogsRequest{
Logs: logs,
SessionId: sessionID.String(),
},
},
}
if err := codec.WriteMessage(conn, codec.TagV2, msg); err != nil {
Expand Down Expand Up @@ -188,7 +197,7 @@ func (s *SocketAuditor) Loop(ctx context.Context) {
return
}

if err := flush(conn, batch); err != nil {
if err := flush(conn, s.sessionID, batch); err != nil {
if err.permanent {
// Data error: discard batch to avoid infinite retries.
s.logger.Warn("dropping batch due to data error on flush attempt",
Expand Down
81 changes: 79 additions & 2 deletions audit/socket_auditor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"testing"
"time"

"github.com/google/uuid"

"github.com/coder/coder/v2/agent/boundarylogproxy/codec"
agentproto "github.com/coder/coder/v2/agent/proto"
)
Expand All @@ -33,6 +35,15 @@ func TestSocketAuditor_AuditRequest_QueuesLog(t *testing.T) {
if log.Allowed != true {
t.Errorf("expected Allowed=true, got %v", log.Allowed)
}
if log.Time == nil {
t.Fatal("expected Time to be set")
}
Comment on lines +38 to +40
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also validate that it's not the zero time

if log.Time.AsTime().IsZero() {
t.Fatal("expected Time to not be the zero time")
}
if log.SequenceNumber != 0 {
t.Errorf("expected first SequenceNumber=0, got %d", log.SequenceNumber)
}
httpReq := log.GetHttpRequest()
if httpReq == nil {
t.Fatal("expected HttpRequest, got nil")
Expand Down Expand Up @@ -230,6 +241,8 @@ func TestSocketAuditor_Loop_RetriesOnConnectionFailure(t *testing.T) {
logCh: make(chan *agentproto.BoundaryLog, 2*defaultBatchSize),
batchSize: defaultBatchSize,
batchTimerDuration: time.Hour, // Ensure timer doesn't interfere with the test
sessionID: uuid.MustParse("00000000-0000-4000-8000-000000000001"),
seq: &SequenceCounter{},
}

// Set up hook to detect flush attempts
Expand Down Expand Up @@ -349,6 +362,8 @@ func TestSocketAuditor_Loop_ReportsBatchFullDrops(t *testing.T) {
logCh: make(chan *agentproto.BoundaryLog, 2*defaultBatchSize),
batchSize: defaultBatchSize,
batchTimerDuration: time.Hour,
sessionID: uuid.MustParse("00000000-0000-4000-8000-000000000001"),
seq: &SequenceCounter{},
}

flushed := make(chan struct{}, 4)
Expand Down Expand Up @@ -451,17 +466,75 @@ func TestSocketAuditor_Loop_ShutdownFlushIncludesDrops(t *testing.T) {
func TestFlush_EmptyBatch(t *testing.T) {
t.Parallel()

err := flush(nil, nil)
err := flush(nil, uuid.Nil, nil)
if err != nil {
t.Errorf("expected nil error for empty batch, got %v", err)
}

err = flush(nil, []*agentproto.BoundaryLog{})
err = flush(nil, uuid.Nil, []*agentproto.BoundaryLog{})
if err != nil {
t.Errorf("expected nil error for empty slice, got %v", err)
}
}

func TestSocketAuditor_AuditRequest_SequenceNumberIncrements(t *testing.T) {
t.Parallel()

auditor := setupSocketAuditor(t)

for i := range 5 {
auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true})

select {
case log := <-auditor.logCh:
if log.SequenceNumber != int32(i) {
t.Errorf("request %d: expected SequenceNumber=%d, got %d", i, i, log.SequenceNumber)
}
default:
t.Fatalf("request %d: expected log in channel, got none", i)
}
}
}

func TestSocketAuditor_Loop_FlushIncludesSessionID(t *testing.T) {
t.Parallel()

auditor, serverConn := setupTestAuditor(t)
auditor.batchTimerDuration = time.Hour

cr := newConnReader()
go readFromConn(t, serverConn, cr)

go auditor.Loop(t.Context())

// Fill a batch to trigger a flush.
for i := 0; i < auditor.batchSize; i++ {
auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true})
}

select {
case req := <-cr.logs:
expectedSessionID := uuid.MustParse("00000000-0000-4000-8000-000000000001").String()
if req.SessionId != expectedSessionID {
t.Errorf("expected SessionId=%s, got %q", expectedSessionID, req.SessionId)
}
if len(req.Logs) != auditor.batchSize {
t.Errorf("expected %d logs, got %d", auditor.batchSize, len(req.Logs))
}
// Verify sequence numbers are monotonically increasing.
for i, log := range req.Logs {
if log.SequenceNumber != int32(i) {
t.Errorf("log %d: expected SequenceNumber=%d, got %d", i, i, log.SequenceNumber)
}
if log.Time == nil {
t.Errorf("log %d: expected Time to be set", i)
}
}
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for flush")
}
}

// setupSocketAuditor creates a SocketAuditor for tests that only exercise
// the queueing behavior (no connection needed).
func setupSocketAuditor(t *testing.T) *SocketAuditor {
Expand All @@ -475,6 +548,8 @@ func setupSocketAuditor(t *testing.T) *SocketAuditor {
logCh: make(chan *agentproto.BoundaryLog, 2*defaultBatchSize),
batchSize: defaultBatchSize,
batchTimerDuration: defaultBatchTimerDuration,
sessionID: uuid.MustParse("00000000-0000-4000-8000-000000000001"),
seq: &SequenceCounter{},
}
}

Expand Down Expand Up @@ -504,6 +579,8 @@ func setupTestAuditor(t *testing.T) (*SocketAuditor, net.Conn) {
logCh: make(chan *agentproto.BoundaryLog, 2*defaultBatchSize),
batchSize: defaultBatchSize,
batchTimerDuration: defaultBatchTimerDuration,
sessionID: uuid.MustParse("00000000-0000-4000-8000-000000000001"),
seq: &SequenceCounter{},
}

return auditor, serverConn
Expand Down
6 changes: 6 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"

"github.com/coder/serpent"
"github.com/google/uuid"
"github.com/spf13/pflag"
)

Expand Down Expand Up @@ -85,6 +86,11 @@ type AppConfig struct {
UserInfo *UserInfo
DisableAuditLogs bool
LogProxySocketPath string

// SessionID is a UUIDv4 generated at process startup. It groups
// all audit events produced by this boundary invocation into a
// single session. Set by Run, not by configuration.
SessionID uuid.UUID
}

func NewAppConfigFromCliConfig(cfg CliConfig, targetCMD []string) (AppConfig, error) {
Expand Down
Loading
Loading