Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
udsamani committed May 24, 2024
1 parent 138ffb4 commit 8afa6ea
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 79 deletions.
5 changes: 2 additions & 3 deletions pkg/nats/proxy/compute_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ type ComputeProxy struct {

func NewComputeProxy(params ComputeProxyParams) (*ComputeProxy, error) {
sc, err := stream.NewConsumerClient(stream.ConsumerClientParams{
Conn: params.Conn,
HeartBeatRequestSub: requesterEndpointPublishSubject(params.Conn.Opts.Name, StreamHeartBeat),
Conn: params.Conn,
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -138,7 +137,7 @@ func proxyStreamingRequest[Request any, Response any](
if err != nil {
return nil, fmt.Errorf("%T: failed to marshal request: %w", request.Body, err)
}
res, err := client.OpenStream(ctx, subject, data)
res, err := client.OpenStream(ctx, subject, request.TargetNodeID, data)
if err != nil {
return nil, fmt.Errorf("%T: failed to send request to node %s: %w", request.Body, request.TargetNodeID, err)
}
Expand Down
13 changes: 3 additions & 10 deletions pkg/nats/proxy/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,15 @@ package proxy
import "fmt"

const (
ComputeEndpointSubjectPrefix = "node.compute"
RequesterEndpointSubjectPrefix = "node.requester"
CallbackSubjectPrefix = "node.orchestrator"
ManagementSubjectPrefix = "node.management"
ComputeEndpointSubjectPrefix = "node.compute"
CallbackSubjectPrefix = "node.orchestrator"
ManagementSubjectPrefix = "node.management"

AskForBid = "AskForBid/v1"
BidAccepted = "BidAccepted/v1"
BidRejected = "BidRejected/v1"
CancelExecution = "CancelExecution/v1"
ExecutionLogs = "ExecutionLogs/v1"
LogBeatRequest = "LogBeatRequest/v1"
LogBeatResponse = "LogBeatResponse/v1"

OnBidComplete = "OnBidComplete/v1"
OnRunComplete = "OnRunComplete/v1"
Expand All @@ -32,10 +29,6 @@ func computeEndpointPublishSubject(nodeID string, method string) string {
return fmt.Sprintf("%s.%s.%s", ComputeEndpointSubjectPrefix, nodeID, method)
}

func requesterEndpointPublishSubject(nodeID string, method string) string {
return fmt.Sprintf("%s.%s.%s", RequesterEndpointSubjectPrefix, nodeID, method)
}

func computeEndpointSubscribeSubject(nodeID string) string {
return fmt.Sprintf("%s.%s.>", ComputeEndpointSubjectPrefix, nodeID)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/nats/stream/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (suite *BaseTestSuite) newTestStream() *testStream {
})
suite.Require().NoError(err)

ch, err := suite.streamingClient.OpenStream(suite.ctx, subject, []byte("test data"))
ch, err := suite.streamingClient.OpenStream(suite.ctx, subject, "", []byte("test data"))
suite.Require().NoError(err)
suite.Require().NotNil(ch)
s.ch = ch
Expand Down
100 changes: 71 additions & 29 deletions pkg/nats/stream/consumer_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"math/rand"
"sort"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -32,21 +33,25 @@ const (
// streamingBucket is a structure to hold the response channel and context
type streamingBucket struct {
// ctx is the context for the channel consumer that requested and waiting for messages
ctx context.Context
token string
ch chan *concurrency.AsyncResult[[]byte]
cancel context.CancelFunc
closeOnce sync.Once
ctx context.Context
token string
createdAt time.Time
producerConnID string
ch chan *concurrency.AsyncResult[[]byte]
cancel context.CancelFunc
closeOnce sync.Once
}

// newStreamingBucket creates a new streamingBucket.
func newStreamingBucket(ctx context.Context, token string) *streamingBucket {
func newStreamingBucket(ctx context.Context, token string, producerConnID string) *streamingBucket {
ctx, cancel := context.WithCancel(ctx)
return &streamingBucket{
ctx: ctx,
cancel: cancel,
token: token,
ch: make(chan *concurrency.AsyncResult[[]byte], RequestChanLen),
ctx: ctx,
cancel: cancel,
createdAt: time.Now(),
producerConnID: producerConnID,
token: token,
ch: make(chan *concurrency.AsyncResult[[]byte], RequestChanLen),
}
}

Expand All @@ -59,12 +64,12 @@ func (sb *streamingBucket) close() {
}

type ConsumerClientParams struct {
Conn *nats.Conn
HeartBeatRequestSub string
Conn *nats.Conn
}

// ConsumerClient represents a NATS streaming client.
type ConsumerClient struct {
ID string
Conn *nats.Conn
mu sync.RWMutex // Protects access to the response map.

Expand All @@ -84,16 +89,18 @@ type ConsumerClient struct {
// NewConsumerClient creates a new NATS client.
func NewConsumerClient(params ConsumerClientParams) (*ConsumerClient, error) {
nc := &ConsumerClient{
Conn: params.Conn,
respMap: make(map[string]*streamingBucket),
respRand: rand.New(rand.NewSource(time.Now().UnixNano())), //nolint:gosec // using same inbox naming as nats
heartBeatRequestSub: params.HeartBeatRequestSub,
ID: params.Conn.Opts.Name,
Conn: params.Conn,
respMap: make(map[string]*streamingBucket),
respRand: rand.New(rand.NewSource(time.Now().UnixNano())), //nolint:gosec // using same inbox naming as nats
}

// Setup response subscription.
nc.respSubPrefix = fmt.Sprintf("%s.", nc.newInbox())
newInbox := nc.newInbox()
nc.respSubPrefix = fmt.Sprintf("%s.", newInbox)
nc.respSubLen = len(nc.respSubPrefix)
nc.respSub = fmt.Sprintf("%s*", nc.respSubPrefix)
nc.heartBeatRequestSub = fmt.Sprintf("%s.%s", nc.Conn.Opts.Name, newInbox)

// Create the response subscription we will use for all streaming responses.
// This will be on an _SINBOX with an additional terminal token. The subscription
Expand Down Expand Up @@ -229,7 +236,7 @@ func (nc *ConsumerClient) respToken(respInbox string) string {

// OpenStream takes a context, a subject and payload
// in bytes and expects a channel with multiple responses.
func (nc *ConsumerClient) OpenStream(ctx context.Context, subj string, data []byte) (<-chan *concurrency.AsyncResult[[]byte], error) {
func (nc *ConsumerClient) OpenStream(ctx context.Context, subj string, producerConnId string, data []byte) (<-chan *concurrency.AsyncResult[[]byte], error) {
if ctx == nil {
return nil, nats.ErrInvalidContext
}
Expand All @@ -242,49 +249,59 @@ func (nc *ConsumerClient) OpenStream(ctx context.Context, subj string, data []by
return nil, ctx.Err()
}

bucket, err := nc.createNewRequestAndSend(ctx, subj, data)
bucket, err := nc.createNewRequestAndSend(ctx, subj, producerConnId, data)
if err != nil {
return nil, err
}
return bucket.ch, nil
}

func (nc *ConsumerClient) heartBeatRespHandler(msg *nats.Msg) {
nc.mu.Lock()
defer nc.mu.Unlock()
streamIds := make([]string, 0, len(nc.respMap))
for k := range nc.respMap {
streamIds = append(streamIds, k)
request := new(HeartBeatRequest)
err := json.Unmarshal(msg.Data, request)
if err != nil {
log.Err(err)
return
}

log.Info().Msgf("Request = %s", request)
var nonRecentStreamIds []string
for k, v := range nc.respMap {
if v.producerConnID == request.ProducerConnID && time.Since(v.createdAt) > 10*time.Second {
nonRecentStreamIds = append(nonRecentStreamIds, k)
}
}

data, err := json.Marshal(HeartBeatResponse{StreamIds: streamIds})
data, err := json.Marshal(ConsumerHeartBeatResponse{NonActiveStreamIds: Difference(nonRecentStreamIds, request.ActiveStreamIds)})
if err != nil {
log.Err(err)
return
}

err = nc.Conn.Publish(msg.Reply, data)
if err != nil {
log.Err(err)
return
}

}

// createNewRequestAndSend sets up and sends a new request, returning the response bucket.
func (nc *ConsumerClient) createNewRequestAndSend(ctx context.Context, subj string, data []byte) (*streamingBucket, error) {
func (nc *ConsumerClient) createNewRequestAndSend(ctx context.Context, subj string, producerConnId string, data []byte) (*streamingBucket, error) {
nc.mu.Lock()

// Create new literal Inbox and map to a bucket.
respInbox := nc.newRespInbox()
token := respInbox[nc.respSubLen:]
bucket := newStreamingBucket(ctx, token)
bucket := newStreamingBucket(ctx, token, producerConnId)

nc.respMap[token] = bucket
nc.mu.Unlock()

streamRequest := Request{
ConnectionDetails: ConnectionDetails{
ConnId: nc.Conn.Opts.Name,
StreamId: token,
ConnID: nc.Conn.Opts.Name,
StreamID: token,
HeartBeatRequestSub: nc.heartBeatRequestSub,
},
Data: data,
Expand Down Expand Up @@ -315,3 +332,28 @@ func (nc *ConsumerClient) NewWriter(subject string) *Writer {
subject: subject,
}
}

func Difference(a, b []string) []string {
i, j := 0, 0
var diff []string

sort.Strings(a)
sort.Strings(b)

for i < len(a) && j < len(b) {
if a[i] < b[j] {
diff = append(diff, a[i])
i++
} else if b[j] < a[i] {
j++
} else {
i++
j++
}
}
for ; i < len(a); i++ {
diff = append(diff, a[i])
}

return diff
}
2 changes: 1 addition & 1 deletion pkg/nats/stream/consumer_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (suite *ClientTestSuite) TestRequestWithContextCancellation() {
cancel()

// Attempt to make the request
_, err := suite.streamingClient.OpenStream(ctx, subj, payload)
_, err := suite.streamingClient.OpenStream(ctx, subj, "", payload)
suite.Require().Error(err, "Expected an error due to cancelled context")
}

Expand Down
Loading

0 comments on commit 8afa6ea

Please sign in to comment.