diff --git a/ziti/edge/msg_mux.go b/ziti/edge/msg_mux.go index 67b6b6f7..65535e7f 100644 --- a/ziti/edge/msg_mux.go +++ b/ziti/edge/msg_mux.go @@ -22,11 +22,11 @@ import ( "github.com/openziti/channel/v4" "github.com/openziti/sdk-golang/inspect" "github.com/openziti/sdk-golang/xgress" + cmap "github.com/orcaman/concurrent-map/v2" "github.com/pkg/errors" "github.com/sirupsen/logrus" "math" "strings" - "sync" "sync/atomic" "time" ) @@ -47,28 +47,28 @@ type MsgMux interface { GetNextId() uint32 } -func NewCowMapMsgMux() MsgMux { - result := &CowMapMsgMux{ +func NewMapMsgMux() MsgMux { + result := &MsgMuxImpl{ maxId: (math.MaxUint32 / 2) - 1, + sinks: cmap.NewWithCustomShardingFunction[uint32, MsgSink](func(key uint32) uint32 { + return key + }), } - result.sinks.Store(map[uint32]MsgSink{}) return result } -type CowMapMsgMux struct { - sync.Mutex +type MsgMuxImpl struct { closed atomic.Bool - sinks atomic.Value + sinks cmap.ConcurrentMap[uint32, MsgSink] nextId uint32 minId uint32 maxId uint32 } -func (mux *CowMapMsgMux) GetNextId() uint32 { +func (mux *MsgMuxImpl) GetNextId() uint32 { nextId := atomic.AddUint32(&mux.nextId, 1) - sinks := mux.getSinks() for { - if _, found := sinks[nextId]; found { + if _, found := mux.sinks.Get(nextId); found { // if it's in use, try next one nextId = atomic.AddUint32(&mux.nextId, 1) } else if nextId < mux.minId || nextId >= mux.maxId { @@ -82,11 +82,11 @@ func (mux *CowMapMsgMux) GetNextId() uint32 { } } -func (mux *CowMapMsgMux) ContentType() int32 { +func (mux *MsgMuxImpl) ContentType() int32 { return ContentTypeData } -func (mux *CowMapMsgMux) HandleReceive(msg *channel.Message, ch channel.Channel) { +func (mux *MsgMuxImpl) HandleReceive(msg *channel.Message, ch channel.Channel) { connId, found := msg.GetUint32Header(ConnIdHeader) if !found { if msg.ContentType == ContentTypeInspectRequest { @@ -97,22 +97,45 @@ func (mux *CowMapMsgMux) HandleReceive(msg *channel.Message, ch channel.Channel) return } - sinks := mux.getSinks() - if sink, found := sinks[connId]; found { + if sink, found := mux.sinks.Get(connId); found { sink.Accept(msg) } else if msg.ContentType == ContentTypeConnInspectRequest { - pfxlog.Logger().WithField("connId", connId).Trace("no conn found for connection inspect") + pfxlog.Logger().WithField("connId", int(connId)).Trace("no conn found for connection inspect") resp := NewConnInspectResponse(connId, ConnTypeInvalid, fmt.Sprintf("invalid conn id [%v]", connId)) if err := resp.ReplyTo(msg).Send(ch); err != nil { logrus.WithFields(GetLoggerFields(msg)).WithError(err). Error("failed to send inspect response") } + } else if msg.ContentType == ContentTypeXgPayload { + mux.handlePayloadWithNoSink(msg, ch) + } else if msg.ContentType == ContentTypeStateClosed { + // ignore, as conn is already closed } else { - pfxlog.Logger().Debugf("unable to dispatch msg received for unknown edge conn id: %v", connId) + pfxlog.Logger().WithField("connId", connId).WithField("contentType", msg.ContentType). + Debug("unable to dispatch msg received for unknown edge conn id") } } -func (mux *CowMapMsgMux) HandleInspect(msg *channel.Message, ch channel.Channel) { +func (mux *MsgMuxImpl) handlePayloadWithNoSink(msg *channel.Message, ch channel.Channel) { + connId, _ := msg.GetUint32Header(ConnIdHeader) + payload, err := xgress.UnmarshallPayload(msg) + if err == nil { + if payload.IsCircuitEndFlagSet() && len(payload.Data) == 0 { + ack := xgress.NewAcknowledgement(payload.CircuitId, payload.GetOriginator().Invert()) + ackMsg := ack.Marshall() + ackMsg.PutUint32Header(ConnIdHeader, connId) + _, _ = ch.TrySend(msg) + } else { + pfxlog.Logger().WithField("connId", int(connId)).WithField("circuitId", payload.CircuitId). + Debug("unable to dispatch xg payload received for unknown edge conn id") + } + } else { + pfxlog.Logger().WithError(err).WithField("connId", int(connId)). + Debug("unable to dispatch xg payload received for unknown edge conn id") + } +} + +func (mux *MsgMuxImpl) HandleInspect(msg *channel.Message, ch channel.Channel) { resp := &inspect.SdkInspectResponse{ Success: true, Values: make(map[string]any), @@ -132,7 +155,7 @@ func (mux *CowMapMsgMux) HandleInspect(msg *channel.Message, ch channel.Channel) Circuits: make(map[string]*xgress.CircuitDetail), } - for _, sink := range mux.getSinks() { + for _, sink := range mux.sinks.Items() { if circuitInfoSrc, ok := sink.(interface { GetCircuitDetail() *xgress.CircuitDetail }); ok { @@ -149,7 +172,7 @@ func (mux *CowMapMsgMux) HandleInspect(msg *channel.Message, ch channel.Channel) mux.returnInspectResponse(msg, ch, resp) } -func (mux *CowMapMsgMux) returnInspectResponse(msg *channel.Message, ch channel.Channel, resp *inspect.SdkInspectResponse) { +func (mux *MsgMuxImpl) returnInspectResponse(msg *channel.Message, ch channel.Channel, resp *inspect.SdkInspectResponse) { var sender channel.Sender = ch if mc, ok := ch.(channel.MultiChannel); ok { if sdkChan, ok := mc.GetUnderlayHandler().(SdkChannel); ok { @@ -169,61 +192,35 @@ func (mux *CowMapMsgMux) returnInspectResponse(msg *channel.Message, ch channel. } } -func (mux *CowMapMsgMux) HandleClose(channel.Channel) { +func (mux *MsgMuxImpl) HandleClose(channel.Channel) { mux.Close() } -func (mux *CowMapMsgMux) AddMsgSink(sink MsgSink) error { +func (mux *MsgMuxImpl) AddMsgSink(sink MsgSink) error { if mux.closed.Load() { return errors.Errorf("mux is closed, can't add sink with id [%v]", sink.Id()) } - var err error - mux.updateSinkMap(func(m map[uint32]MsgSink) { - if _, found := m[sink.Id()]; found { - err = errors.Errorf("sink id %v already in use", sink.Id()) - } else { - m[sink.Id()] = sink - } - }) - - // check again, just in case it was closed while we were adding - if mux.closed.Load() { - return errors.Errorf("mux is closed, can't add sink with id [%v]", sink.Id()) + if !mux.sinks.SetIfAbsent(sink.Id(), sink) { + return errors.Errorf("sink id %v already in use", sink.Id()) } - - return err + return nil } -func (mux *CowMapMsgMux) RemoveMsgSink(sink MsgSink) { +func (mux *MsgMuxImpl) RemoveMsgSink(sink MsgSink) { mux.RemoveMsgSinkById(sink.Id()) } -func (mux *CowMapMsgMux) RemoveMsgSinkById(sinkId uint32) { - mux.updateSinkMap(func(m map[uint32]MsgSink) { - delete(m, sinkId) - }) +func (mux *MsgMuxImpl) RemoveMsgSinkById(sinkId uint32) { + mux.sinks.Remove(sinkId) } -func (mux *CowMapMsgMux) updateSinkMap(f func(map[uint32]MsgSink)) { - mux.Lock() - defer mux.Unlock() - - current := mux.getSinks() - result := map[uint32]MsgSink{} - for k, v := range current { - result[k] = v - } - f(result) - mux.sinks.Store(result) -} - -func (mux *CowMapMsgMux) Close() { +func (mux *MsgMuxImpl) Close() { if mux.closed.CompareAndSwap(false, true) { // we don't need to lock the mux because due to the atomic bool, only one go-routine will enter this. // If the sink HandleMuxClose methods do anything with the mux, like remove themselves, they will acquire // their own locks - sinks := mux.getSinks() + sinks := mux.sinks.Items() for _, val := range sinks { if err := val.HandleMuxClose(); err != nil { pfxlog.Logger(). @@ -234,7 +231,3 @@ func (mux *CowMapMsgMux) Close() { } } } - -func (mux *CowMapMsgMux) getSinks() map[uint32]MsgSink { - return mux.sinks.Load().(map[uint32]MsgSink) -} diff --git a/ziti/edge/network/conn_test.go b/ziti/edge/network/conn_test.go index 435804ac..de50777b 100644 --- a/ziti/edge/network/conn_test.go +++ b/ziti/edge/network/conn_test.go @@ -32,7 +32,7 @@ func BenchmarkConnWrite(b *testing.B) { closeNotify := make(chan struct{}) defer close(closeNotify) - mux := edge.NewCowMapMsgMux() + mux := edge.NewMapMsgMux() testChannel := edge.NewSingleSdkChannel(&NoopTestChannel{}) conn := &edgeConn{ MsgChannel: *edge.NewEdgeMsgChannel(testChannel, 1), @@ -58,7 +58,7 @@ func BenchmarkConnRead(b *testing.B) { closeNotify := make(chan struct{}) defer close(closeNotify) - mux := edge.NewCowMapMsgMux() + mux := edge.NewMapMsgMux() testChannel := edge.NewSingleSdkChannel(&NoopTestChannel{}) readQ := NewNoopSequencer[*channel.Message](closeNotify, 4) @@ -135,7 +135,7 @@ func TestReadMultipart(t *testing.T) { closeNotify := make(chan struct{}) defer close(closeNotify) - mux := edge.NewCowMapMsgMux() + mux := edge.NewMapMsgMux() testChannel := edge.NewSingleSdkChannel(&NoopTestChannel{}) readQ := NewNoopSequencer[*channel.Message](closeNotify, 4) diff --git a/ziti/edge/network/factory.go b/ziti/edge/network/factory.go index 33ff91aa..f0d3b684 100644 --- a/ziti/edge/network/factory.go +++ b/ziti/edge/network/factory.go @@ -65,7 +65,7 @@ func NewEdgeConnFactory(routerName, key string, owner RouterConnOwner) edge.Rout connFactory := &routerConn{ key: key, routerName: routerName, - msgMux: edge.NewCowMapMsgMux(), + msgMux: edge.NewMapMsgMux(), owner: owner, }