Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 52 additions & 59 deletions ziti/edge/msg_mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ import (
"github.com/openziti/channel/v4"
"github.com/openziti/sdk-golang/inspect"
"github.com/openziti/sdk-golang/xgress"
cmap "github.com/orcaman/concurrent-map/v2"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"math"
"strings"
"sync"
"sync/atomic"
"time"
)
Expand All @@ -47,28 +47,28 @@ type MsgMux interface {
GetNextId() uint32
}

func NewCowMapMsgMux() MsgMux {
result := &CowMapMsgMux{
func NewMapMsgMux() MsgMux {
result := &MsgMuxImpl{
maxId: (math.MaxUint32 / 2) - 1,
sinks: cmap.NewWithCustomShardingFunction[uint32, MsgSink](func(key uint32) uint32 {
return key
}),
}
result.sinks.Store(map[uint32]MsgSink{})
return result
}

type CowMapMsgMux struct {
sync.Mutex
type MsgMuxImpl struct {
closed atomic.Bool
sinks atomic.Value
sinks cmap.ConcurrentMap[uint32, MsgSink]
nextId uint32
minId uint32
maxId uint32
}

func (mux *CowMapMsgMux) GetNextId() uint32 {
func (mux *MsgMuxImpl) GetNextId() uint32 {
nextId := atomic.AddUint32(&mux.nextId, 1)
sinks := mux.getSinks()
for {
if _, found := sinks[nextId]; found {
if _, found := mux.sinks.Get(nextId); found {
// if it's in use, try next one
nextId = atomic.AddUint32(&mux.nextId, 1)
} else if nextId < mux.minId || nextId >= mux.maxId {
Expand All @@ -82,11 +82,11 @@ func (mux *CowMapMsgMux) GetNextId() uint32 {
}
}

func (mux *CowMapMsgMux) ContentType() int32 {
func (mux *MsgMuxImpl) ContentType() int32 {
return ContentTypeData
}

func (mux *CowMapMsgMux) HandleReceive(msg *channel.Message, ch channel.Channel) {
func (mux *MsgMuxImpl) HandleReceive(msg *channel.Message, ch channel.Channel) {
connId, found := msg.GetUint32Header(ConnIdHeader)
if !found {
if msg.ContentType == ContentTypeInspectRequest {
Expand All @@ -97,22 +97,45 @@ func (mux *CowMapMsgMux) HandleReceive(msg *channel.Message, ch channel.Channel)
return
}

sinks := mux.getSinks()
if sink, found := sinks[connId]; found {
if sink, found := mux.sinks.Get(connId); found {
sink.Accept(msg)
} else if msg.ContentType == ContentTypeConnInspectRequest {
pfxlog.Logger().WithField("connId", connId).Trace("no conn found for connection inspect")
pfxlog.Logger().WithField("connId", int(connId)).Trace("no conn found for connection inspect")
resp := NewConnInspectResponse(connId, ConnTypeInvalid, fmt.Sprintf("invalid conn id [%v]", connId))
if err := resp.ReplyTo(msg).Send(ch); err != nil {
logrus.WithFields(GetLoggerFields(msg)).WithError(err).
Error("failed to send inspect response")
}
} else if msg.ContentType == ContentTypeXgPayload {
mux.handlePayloadWithNoSink(msg, ch)
} else if msg.ContentType == ContentTypeStateClosed {
// ignore, as conn is already closed
} else {
pfxlog.Logger().Debugf("unable to dispatch msg received for unknown edge conn id: %v", connId)
pfxlog.Logger().WithField("connId", connId).WithField("contentType", msg.ContentType).
Debug("unable to dispatch msg received for unknown edge conn id")
}
}

func (mux *CowMapMsgMux) HandleInspect(msg *channel.Message, ch channel.Channel) {
func (mux *MsgMuxImpl) handlePayloadWithNoSink(msg *channel.Message, ch channel.Channel) {
connId, _ := msg.GetUint32Header(ConnIdHeader)
payload, err := xgress.UnmarshallPayload(msg)
if err == nil {
if payload.IsCircuitEndFlagSet() && len(payload.Data) == 0 {
ack := xgress.NewAcknowledgement(payload.CircuitId, payload.GetOriginator().Invert())
ackMsg := ack.Marshall()
ackMsg.PutUint32Header(ConnIdHeader, connId)
_, _ = ch.TrySend(msg)
} else {
pfxlog.Logger().WithField("connId", int(connId)).WithField("circuitId", payload.CircuitId).
Debug("unable to dispatch xg payload received for unknown edge conn id")
}
} else {
pfxlog.Logger().WithError(err).WithField("connId", int(connId)).
Debug("unable to dispatch xg payload received for unknown edge conn id")
}
}

func (mux *MsgMuxImpl) HandleInspect(msg *channel.Message, ch channel.Channel) {
resp := &inspect.SdkInspectResponse{
Success: true,
Values: make(map[string]any),
Expand All @@ -132,7 +155,7 @@ func (mux *CowMapMsgMux) HandleInspect(msg *channel.Message, ch channel.Channel)
Circuits: make(map[string]*xgress.CircuitDetail),
}

for _, sink := range mux.getSinks() {
for _, sink := range mux.sinks.Items() {
if circuitInfoSrc, ok := sink.(interface {
GetCircuitDetail() *xgress.CircuitDetail
}); ok {
Expand All @@ -149,7 +172,7 @@ func (mux *CowMapMsgMux) HandleInspect(msg *channel.Message, ch channel.Channel)
mux.returnInspectResponse(msg, ch, resp)
}

func (mux *CowMapMsgMux) returnInspectResponse(msg *channel.Message, ch channel.Channel, resp *inspect.SdkInspectResponse) {
func (mux *MsgMuxImpl) returnInspectResponse(msg *channel.Message, ch channel.Channel, resp *inspect.SdkInspectResponse) {
var sender channel.Sender = ch
if mc, ok := ch.(channel.MultiChannel); ok {
if sdkChan, ok := mc.GetUnderlayHandler().(SdkChannel); ok {
Expand All @@ -169,61 +192,35 @@ func (mux *CowMapMsgMux) returnInspectResponse(msg *channel.Message, ch channel.
}
}

func (mux *CowMapMsgMux) HandleClose(channel.Channel) {
func (mux *MsgMuxImpl) HandleClose(channel.Channel) {
mux.Close()
}

func (mux *CowMapMsgMux) AddMsgSink(sink MsgSink) error {
func (mux *MsgMuxImpl) AddMsgSink(sink MsgSink) error {
if mux.closed.Load() {
return errors.Errorf("mux is closed, can't add sink with id [%v]", sink.Id())
}

var err error
mux.updateSinkMap(func(m map[uint32]MsgSink) {
if _, found := m[sink.Id()]; found {
err = errors.Errorf("sink id %v already in use", sink.Id())
} else {
m[sink.Id()] = sink
}
})

// check again, just in case it was closed while we were adding
if mux.closed.Load() {
return errors.Errorf("mux is closed, can't add sink with id [%v]", sink.Id())
if !mux.sinks.SetIfAbsent(sink.Id(), sink) {
return errors.Errorf("sink id %v already in use", sink.Id())
}

return err
return nil
}

func (mux *CowMapMsgMux) RemoveMsgSink(sink MsgSink) {
func (mux *MsgMuxImpl) RemoveMsgSink(sink MsgSink) {
mux.RemoveMsgSinkById(sink.Id())
}

func (mux *CowMapMsgMux) RemoveMsgSinkById(sinkId uint32) {
mux.updateSinkMap(func(m map[uint32]MsgSink) {
delete(m, sinkId)
})
func (mux *MsgMuxImpl) RemoveMsgSinkById(sinkId uint32) {
mux.sinks.Remove(sinkId)
}

func (mux *CowMapMsgMux) updateSinkMap(f func(map[uint32]MsgSink)) {
mux.Lock()
defer mux.Unlock()

current := mux.getSinks()
result := map[uint32]MsgSink{}
for k, v := range current {
result[k] = v
}
f(result)
mux.sinks.Store(result)
}

func (mux *CowMapMsgMux) Close() {
func (mux *MsgMuxImpl) Close() {
if mux.closed.CompareAndSwap(false, true) {
// we don't need to lock the mux because due to the atomic bool, only one go-routine will enter this.
// If the sink HandleMuxClose methods do anything with the mux, like remove themselves, they will acquire
// their own locks
sinks := mux.getSinks()
sinks := mux.sinks.Items()
for _, val := range sinks {
if err := val.HandleMuxClose(); err != nil {
pfxlog.Logger().
Expand All @@ -234,7 +231,3 @@ func (mux *CowMapMsgMux) Close() {
}
}
}

func (mux *CowMapMsgMux) getSinks() map[uint32]MsgSink {
return mux.sinks.Load().(map[uint32]MsgSink)
}
6 changes: 3 additions & 3 deletions ziti/edge/network/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func BenchmarkConnWrite(b *testing.B) {
closeNotify := make(chan struct{})
defer close(closeNotify)

mux := edge.NewCowMapMsgMux()
mux := edge.NewMapMsgMux()
testChannel := edge.NewSingleSdkChannel(&NoopTestChannel{})
conn := &edgeConn{
MsgChannel: *edge.NewEdgeMsgChannel(testChannel, 1),
Expand All @@ -58,7 +58,7 @@ func BenchmarkConnRead(b *testing.B) {
closeNotify := make(chan struct{})
defer close(closeNotify)

mux := edge.NewCowMapMsgMux()
mux := edge.NewMapMsgMux()
testChannel := edge.NewSingleSdkChannel(&NoopTestChannel{})

readQ := NewNoopSequencer[*channel.Message](closeNotify, 4)
Expand Down Expand Up @@ -135,7 +135,7 @@ func TestReadMultipart(t *testing.T) {
closeNotify := make(chan struct{})
defer close(closeNotify)

mux := edge.NewCowMapMsgMux()
mux := edge.NewMapMsgMux()
testChannel := edge.NewSingleSdkChannel(&NoopTestChannel{})

readQ := NewNoopSequencer[*channel.Message](closeNotify, 4)
Expand Down
2 changes: 1 addition & 1 deletion ziti/edge/network/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func NewEdgeConnFactory(routerName, key string, owner RouterConnOwner) edge.Rout
connFactory := &routerConn{
key: key,
routerName: routerName,
msgMux: edge.NewCowMapMsgMux(),
msgMux: edge.NewMapMsgMux(),
owner: owner,
}

Expand Down
Loading