From 25f4360b57c8676978df7c604f8c9867db87ff5c Mon Sep 17 00:00:00 2001
From: Paul Lorenz
Date: Sat, 7 Jun 2025 00:48:23 -0400
Subject: [PATCH] Rework xgress to allow rx and tx sides to be closed independently. Fixes #765

---
 CHANGELOG.md                      |  19 +-
 version                           |   2 +-
 xgress/link_send_buffer.go        |  78 ++++-
 xgress/messages.go                |  10 +
 xgress/minimal_payload_test.go    |   3 +-
 xgress/ordering_test.go           |   3 +-
 xgress/write_timeout_test.go      | 163 +++++++++++
 xgress/xgress.go                  | 472 +++++++++++++++++++++---------
 ziti/edge/msg_mux.go              |   2 +-
 ziti/edge/network/conn.go         | 112 ++++---
 ziti/edge/network/hosting_conn.go |   5 +-
 ziti/edge/network/seq.go          |  35 ++-
 ziti/edge/network/seq_test.go     |   7 +
 ziti/edge/network/xg_adapter.go   |  72 +++--
 ziti/sdkinfo/build_info.go        |   2 +-
 ziti/ziti.go                      |   2 +-
 16 files changed, 747 insertions(+), 240 deletions(-)
 create mode 100644 xgress/write_timeout_test.go

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d35af9a2..9fff8f6b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,8 +1,23 @@
-# Release notes 1.1.3
+# Release notes 1.2.0
+
+## What's New
+
+This release contains substantial revisions to the SDK flow control feature first released in v1.1.0.
+See the v1.1.0 release notes for more details.
+
+It has now received a substantial amount of testing, including long-running tests and backwards compatibility testing.
+
+These features should be used with version 1.6.6 or newer of OpenZiti.
+
+It is still considered experimental, and the feature and APIs may still change; however, Go SDK
+users who are multiplexing connections are encouraged to try it out.
+
+Once it has undergone sufficient soak time in a production environment, it will be marked as stable.
 
 ## Issues Fixed and Dependency Updates
 
-* github.com/openziti/sdk-golang: [v1.1.2 -> v1.1.3](https://github.com/openziti/sdk-golang/compare/v1.1.2...v1.1.3)
+* github.com/openziti/sdk-golang: [v1.1.2 -> v1.2.0](https://github.com/openziti/sdk-golang/compare/v1.1.2...v1.2.0)
+  * [Issue #765](https://github.com/openziti/sdk-golang/issues/765) - Allow independent close of xgress send and receive
   * [Issue #763](https://github.com/openziti/sdk-golang/issues/763) - Use a go-routine pool for payload ingest
   * [Issue #761](https://github.com/openziti/sdk-golang/issues/761) - Use cmap.ConcurrentMap for message multiplexer
   * [Issue #754](https://github.com/openziti/sdk-golang/issues/754) - panic: unaligned 64-bit atomic operation when running on 32-bit raspberry pi

diff --git a/version b/version
index 9459d4ba..5625e59d 100644
--- a/version
+++ b/version
@@ -1 +1 @@
-1.1
+1.2

diff --git a/xgress/link_send_buffer.go b/xgress/link_send_buffer.go
index 08049cab..c14dccc8 100644
--- a/xgress/link_send_buffer.go
+++ b/xgress/link_send_buffer.go
@@ -17,10 +17,11 @@ package xgress
 
 import (
+	"context"
 	"github.com/michaelquigley/pfxlog"
-	"github.com/pkg/errors"
 	"github.com/sirupsen/logrus"
 	"math"
+	"os"
 	"slices"
 	"sync/atomic"
 	"time"
@@ -52,6 +53,7 @@ type LinkSendBuffer struct {
 	closeWhenEmpty    atomic.Bool
 	inspectRequests   chan *sendBufferInspectEvent
 	blockedSince      time.Time
+	closeStart        time.Time
 }
 
 type txPayload struct {
@@ -111,22 +113,39 @@ func NewLinkSendBuffer(x *Xgress) *LinkSendBuffer {
 		inspectRequests:   make(chan *sendBufferInspectEvent, 1),
 	}
 
-	go buffer.run()
 	return buffer
 }
 
 func (buffer *LinkSendBuffer) CloseWhenEmpty() bool {
+	pfxlog.ContextLogger(buffer.x.Label()).Debug("close when empty")
 	return buffer.closeWhenEmpty.CompareAndSwap(false, true)
 }
 
 func (buffer *LinkSendBuffer) BufferPayload(payload *Payload) (func(), error) {
 	txPayload :=
&txPayload{payload: payload, age: math.MaxInt64, x: buffer.x} + select { case buffer.newlyBuffered <- txPayload: pfxlog.ContextLogger(buffer.x.Label()).Debugf("buffered [%d]", payload.GetSequence()) return txPayload.markSent, nil case <-buffer.closeNotify: - return nil, errors.Errorf("payload buffer closed") + return nil, ErrWriteClosed + } +} + +func (buffer *LinkSendBuffer) BufferPayloadWithDeadline(payload *Payload, ctx context.Context) (func(), error) { + txPayload := &txPayload{payload: payload, age: math.MaxInt64, x: buffer.x} + + for { + select { + case <-ctx.Done(): + return nil, os.ErrDeadlineExceeded + case buffer.newlyBuffered <- txPayload: + pfxlog.ContextLogger(buffer.x.Label()).Debugf("buffered [%d]", payload.GetSequence()) + return txPayload.markSent, nil + case <-buffer.closeNotify: + return nil, ErrWriteClosed + } } } @@ -151,10 +170,15 @@ func (buffer *LinkSendBuffer) metrics() Metrics { } func (buffer *LinkSendBuffer) Close() { - pfxlog.ContextLogger(buffer.x.Label()).Debugf("[%p] closing", buffer) if buffer.closed.CompareAndSwap(false, true) { + pfxlog.ContextLogger(buffer.x.Label()).Debugf("[%p] closing", buffer) close(buffer.closeNotify) } + buffer.x.closeIfRxAndTxDone() +} + +func (buffer *LinkSendBuffer) IsClosed() bool { + return buffer.closed.Load() } func (buffer *LinkSendBuffer) isBlocked() bool { @@ -211,7 +235,7 @@ func (buffer *LinkSendBuffer) run() { case ack := <-buffer.newlyReceivedAcks: buffer.receiveAcknowledgement(ack) case <-buffer.closeNotify: - buffer.close() + buffer.cleanupMetrics() return default: } @@ -232,7 +256,7 @@ func (buffer *LinkSendBuffer) run() { log.Tracef("buffering payload %v with size %v. payload buffer size: %v", txPayload.payload.Sequence, len(txPayload.payload.Data), buffer.linkSendBufferSize) case <-buffer.closeNotify: - buffer.close() + buffer.cleanupMetrics() return default: } @@ -245,9 +269,7 @@ func (buffer *LinkSendBuffer) run() { case ack := <-buffer.newlyReceivedAcks: buffer.receiveAcknowledgement(ack) buffer.retransmit() - if buffer.closeWhenEmpty.Load() && len(buffer.buffer) == 0 && !buffer.x.Closed() && buffer.x.IsEndOfCircuitSent() { - go buffer.x.Close() - } + buffer.checkForClose() case txPayload := <-buffered: buffer.buffer[txPayload.payload.GetSequence()] = txPayload @@ -259,15 +281,46 @@ func (buffer *LinkSendBuffer) run() { case <-retransmitTicker.C: buffer.retransmit() + buffer.checkForClose() case <-buffer.closeNotify: - buffer.close() + buffer.cleanupMetrics() + if len(buffer.buffer) > 0 { + isCircuitEnd := false + if len(buffer.buffer) == 1 { + for _, p := range buffer.buffer { + isCircuitEnd = p.payload.IsCircuitEndFlagSet() || p.payload.IsFlagEOFSet() + } + } + if !isCircuitEnd { + log.WithField("payloadCount", len(buffer.buffer)).Warn("closing while buffer contains unacked payloads") + } + } return } } } -func (buffer *LinkSendBuffer) close() { +func (buffer *LinkSendBuffer) checkForClose() { + if buffer.closeWhenEmpty.Load() { + if buffer.closeStart.IsZero() { + buffer.closeStart = time.Now() + } + closeDuration := time.Since(buffer.closeStart) + + if (len(buffer.buffer) == 0 && closeDuration > 5*time.Second) || closeDuration > buffer.x.Options.MaxCloseWait { + buffer.Close() + } else if len(buffer.buffer) == 1 && closeDuration > 5*time.Second { + for _, p := range buffer.buffer { + if p.payload.IsCircuitEndFlagSet() || p.payload.IsFlagEOFSet() { + buffer.Close() + } + } + } + } +} + +func (buffer *LinkSendBuffer) cleanupMetrics() { if buffer.blockedByLocalWindow { 
 		buffer.metrics().BufferUnblockedByLocalWindow()
 	}
@@ -358,7 +411,7 @@ func (buffer *LinkSendBuffer) retransmit() {
 	}
 
 	if retransmitted > 0 {
-		log.Debugf("retransmitted [%d] payloads, [%d] buffered, linkSendBufferSize: %d", retransmitted, len(buffer.buffer), buffer.linkSendBufferSize)
+		log.WithField("circuitId", buffer.x.circuitId).Debugf("retransmitted [%d] payloads, [%d] buffered, linkSendBufferSize: %d", retransmitted, len(buffer.buffer), buffer.linkSendBufferSize)
 	}
 	buffer.lastRetransmitTime = now
 }
@@ -379,6 +432,7 @@ func (buffer *LinkSendBuffer) inspect() *SendBufferDetail {
 	timeSinceLastRetransmit := time.Duration(time.Now().UnixMilli()-buffer.lastRetransmitTime) * time.Millisecond
 	result := &SendBufferDetail{
 		WindowSize:             buffer.windowsSize,
+		QueuedPayloadCount:     len(buffer.buffer),
 		LinkSendBufferSize:     buffer.linkSendBufferSize,
 		LinkRecvBufferSize:     buffer.linkRecvBufferSize,
 		Accumulator:            buffer.accumulator,
diff --git a/xgress/messages.go b/xgress/messages.go
index a61b014c..dcd997d9 100644
--- a/xgress/messages.go
+++ b/xgress/messages.go
@@ -78,6 +78,8 @@ const (
 	PayloadFlagCircuitStart Flag = 4
 	PayloadFlagChunk        Flag = 8
 	PayloadFlagRetransmit   Flag = 16
+	PayloadFlagEOF          Flag = 32
+	PayloadFlagWriteFailed  Flag = 64
 )
 
 func NewAcknowledgement(circuitId string, originator Originator) *Acknowledgement {
@@ -308,6 +310,14 @@ func (payload *Payload) IsCircuitEndFlagSet() bool {
 	return isFlagSet(payload.Flags, PayloadFlagCircuitEnd)
 }
 
+func (payload *Payload) IsFlagEOFSet() bool {
+	return isFlagSet(payload.Flags, PayloadFlagEOF)
+}
+
+func (payload *Payload) IsFlagWriteFailedSet() bool {
+	return isFlagSet(payload.Flags, PayloadFlagWriteFailed)
+}
+
 func (payload *Payload) IsCircuitStartFlagSet() bool {
 	return isFlagSet(payload.Flags, PayloadFlagCircuitStart)
 }
diff --git a/xgress/minimal_payload_test.go b/xgress/minimal_payload_test.go
index 3ee25630..027b3e3e 100644
--- a/xgress/minimal_payload_test.go
+++ b/xgress/minimal_payload_test.go
@@ -1,6 +1,7 @@
 package xgress
 
 import (
+	"context"
 	"encoding/binary"
 	"errors"
 	"fmt"
@@ -177,7 +178,7 @@ func (self *testIntermediary) ForwardAcknowledgement(ack *Acknowledgement, addre
 	self.acker.SendAck(ack, address)
 }
 
-func (self *testIntermediary) ForwardPayload(payload *Payload, x *Xgress) {
+func (self *testIntermediary) ForwardPayload(payload *Payload, x *Xgress, ctx context.Context) {
 	m := payload.Marshall()
 	self.payloadTransformer.Tx(m, nil)
 	b, err := self.msgs.GetMarshaller()(m)
diff --git a/xgress/ordering_test.go b/xgress/ordering_test.go
index 42d732dd..767767ee 100644
--- a/xgress/ordering_test.go
+++ b/xgress/ordering_test.go
@@ -1,6 +1,7 @@
 package xgress
 
 import (
+	"context"
 	"encoding/binary"
 	"github.com/openziti/channel/v4"
 	"github.com/stretchr/testify/require"
@@ -64,7 +65,7 @@ func (n noopReceiveHandler) GetPayloadIngester() *PayloadIngester {
 
 func (n noopReceiveHandler) ForwardAcknowledgement(*Acknowledgement, Address) {}
 
-func (n noopReceiveHandler) ForwardPayload(*Payload, *Xgress) {}
+func (n noopReceiveHandler) ForwardPayload(*Payload, *Xgress, context.Context) {}
 
 func (n noopReceiveHandler) ForwardControlMessage(*Control, *Xgress) {}
diff --git a/xgress/write_timeout_test.go b/xgress/write_timeout_test.go
new file mode 100644
index 00000000..122e35fe
--- /dev/null
+++ b/xgress/write_timeout_test.go
@@ -0,0 +1,163 @@
+package xgress
+
+import (
+	"github.com/stretchr/testify/require"
+	"testing"
+	"time"
+)
+
+// TestWriteTimeout exercises WriteAdapter.SetWriteDeadline and Done.
+func TestWriteTimeout(t *testing.T) {
+	req := require.New(t)
+
+	writeAdapter := NewWriteAdapter(nil)
+	req.NotNil(writeAdapter.Done())
+
+	// test setting deadline
+	start := time.Now()
+	err := writeAdapter.SetWriteDeadline(start.Add(250 * time.Millisecond))
+	req.NoError(err)
+
+	select {
+	case <-writeAdapter.Done():
+		passed := time.Since(start)
+		req.True(passed >= 250*time.Millisecond, "expected at least 250ms, got %s", passed)
+		req.True(passed <= 350*time.Millisecond, "expected at most 350ms, got %s", passed)
+	case <-time.After(500 * time.Millisecond):
+		req.Fail("timeout didn't fire")
+	}
+
+	// test that deadline doesn't get reset on its own after timeout
+	start = time.Now()
+	select {
+	case <-writeAdapter.Done():
+		passed := time.Since(start)
+		req.True(passed < 10*time.Millisecond, "expected at most 10ms, got %s", passed)
+	case <-time.After(500 * time.Millisecond):
+		req.Fail("timeout didn't fire")
+	}
+
+	// test resetting deadline
+	start = time.Now()
+	err = writeAdapter.SetWriteDeadline(start.Add(250 * time.Millisecond))
+	req.NoError(err)
+
+	select {
+	case <-writeAdapter.Done():
+		passed := time.Since(start)
+		req.True(passed >= 250*time.Millisecond, "expected at least 250ms, got %s", passed)
+		req.True(passed <= 350*time.Millisecond, "expected at most 350ms, got %s", passed)
+	case <-time.After(500 * time.Millisecond):
+		req.Fail("timeout didn't fire")
+	}
+
+	// test that deadline doesn't get reset on its own after timeout
+	start = time.Now()
+	select {
+	case <-writeAdapter.Done():
+		passed := time.Since(start)
+		req.True(passed < 10*time.Millisecond, "expected at most 10ms, got %s", passed)
+	case <-time.After(500 * time.Millisecond):
+		req.Fail("timeout didn't fire")
+	}
+
+	// test setting deadline asynchronously
+	start = time.Now()
+	err = writeAdapter.SetWriteDeadline(time.Time{})
+	req.NoError(err)
+
+	go func() {
+		time.Sleep(100 * time.Millisecond)
+		err = writeAdapter.SetWriteDeadline(time.Now().Add(200 * time.Millisecond))
+		req.NoError(err)
+	}()
+
+	select {
+	case <-writeAdapter.Done():
+		passed := time.Since(start)
+		req.True(passed >= 300*time.Millisecond, "expected at least 300ms, got %s", passed)
+		req.True(passed <= 350*time.Millisecond, "expected at most 350ms, got %s", passed)
+	case <-time.After(500 * time.Millisecond):
+		req.Fail("timeout didn't fire")
+	}
+
+	// test that deadline doesn't get reset on its own after timeout
+	start = time.Now()
+	select {
+	case <-writeAdapter.Done():
+		passed := time.Since(start)
+		req.True(passed < 10*time.Millisecond, "expected at most 10ms, got %s", passed)
+	case <-time.After(500 * time.Millisecond):
+		req.Fail("timeout didn't fire")
+	}
+
+	// test setting deadline and clearing it asynchronously
+	start = time.Now()
+	err = writeAdapter.SetWriteDeadline(start.Add(250 * time.Millisecond))
+	req.NoError(err)
+
+	go func() {
+		time.Sleep(100 * time.Millisecond)
+		err = writeAdapter.SetWriteDeadline(time.Time{})
+		req.NoError(err)
+	}()
+
+	select {
+	case <-writeAdapter.Done():
+		req.Fail("timeout should not have fired")
+	case <-time.After(500 * time.Millisecond):
+		// success: the deadline was cleared before it could fire
+	}
+
+	// test setting deadline asynchronously and clearing it asynchronously
+	err = writeAdapter.SetWriteDeadline(time.Time{})
+	req.NoError(err)
+
+	go func() {
+		err = writeAdapter.SetWriteDeadline(time.Now().Add(250 * time.Millisecond))
+		req.NoError(err)
+		time.Sleep(100 * time.Millisecond)
+		err = writeAdapter.SetWriteDeadline(time.Time{})
+		req.NoError(err)
+
+	}()
+
+	select {
+	case <-writeAdapter.Done():
+		req.Fail("timeout should not have fired")
+	case <-time.After(500 * time.Millisecond):
+		// success: the deadline was cleared before it could fire
+	}
+
+	// test setting deadline to the past
+	start = time.Now()
+	err = writeAdapter.SetWriteDeadline(start.Add(-1 * time.Second))
+	req.NoError(err)
+
+	select {
+	case <-writeAdapter.Done():
+		passed := time.Since(start)
+		req.True(passed < 10*time.Millisecond, "expected at most 10ms, got %s", passed)
+	case <-time.After(500 * time.Millisecond):
+		req.Fail("timeout didn't fire")
+	}
+
+	// test setting deadline to the past asynchronously
+	start = time.Now()
+	err = writeAdapter.SetWriteDeadline(time.Now())
+	req.NoError(err)
+
+	go func() {
+		time.Sleep(5 * time.Millisecond)
+		err = writeAdapter.SetWriteDeadline(time.Now().Add(-250 * time.Millisecond))
+		req.NoError(err)
+	}()
+
+	select {
+	case <-writeAdapter.Done():
+		passed := time.Since(start)
+		req.True(passed < 20*time.Millisecond, "expected at most 20ms, got %s", passed)
+	case <-time.After(500 * time.Millisecond):
+		req.Fail("timeout didn't fire")
+	}
+}
diff --git a/xgress/xgress.go b/xgress/xgress.go
index f5379fac..e7d9ce6f 100644
--- a/xgress/xgress.go
+++ b/xgress/xgress.go
@@ -19,6 +19,7 @@ package xgress
 import (
 	"bufio"
 	"bytes"
+	"context"
 	"encoding/binary"
 	"errors"
 	"fmt"
@@ -45,8 +46,13 @@ const (
 	rxerStartedFlag       = 1
 	endOfCircuitRecvdFlag = 2
 	endOfCircuitSentFlag  = 3
+	closedTxer            = 4
+	rxPushModeFlag        = 5 // false == pull, use rx(); true == push, use WriteAdapter
 )
 
+var ErrWriteClosed = errors.New("write closed")
+var ErrPeerClosed = errors.New("peer closed")
+
 type Address string
 
 type AckSender interface {
@@ -74,7 +80,7 @@ type Env interface {
 // is implemented to connect the xgress to a data plane data transmission system.
type DataPlaneAdapter interface { // ForwardPayload is used to forward data payloads onto the data-plane from an xgress - ForwardPayload(payload *Payload, x *Xgress) + ForwardPayload(payload *Payload, x *Xgress, ctx context.Context) // RetransmitPayload is used to retransmit data payloads onto the data-plane from an xgress RetransmitPayload(srcAddr Address, payload *Payload) error @@ -114,9 +120,15 @@ type Connection interface { LogContext() string ReadPayload() ([]byte, map[uint8][]byte, error) WritePayload([]byte, map[uint8][]byte) (int, error) + HandleControlMsg(controlType ControlType, headers channel.Headers, responder ControlReceiver) error } +type SignalConnection interface { + Connection + FlowFromFabricToXgressClosed() +} + type Xgress struct { timeOfLastRxFromLink int64 // must be first for 64-bit atomic operations on 32-bit machines dataPlane DataPlaneAdapter @@ -220,6 +232,11 @@ func (self *Xgress) firstCircuitStartReceived() bool { return self.flags.CompareAndSet(rxerStartedFlag, false, true) } +func (self *Xgress) NewWriteAdapter() *WriteAdapter { + self.flags.Set(rxPushModeFlag, true) + return NewWriteAdapter(self) +} + func (self *Xgress) Start() { log := pfxlog.ContextLogger(self.Label()) if self.IsTerminator() { @@ -229,8 +246,12 @@ func (self *Xgress) Start() { } } else { log.Debug("initiator: sending circuit start") - self.forwardPayload(self.GetStartCircuit()) - go self.rx() + go self.payloadBuffer.run() + _ = self.forwardPayload(self.GetStartCircuit(), context.Background()) + + if !self.flags.IsSet(rxPushModeFlag) { + go self.rx() + } } go self.tx() } @@ -270,7 +291,6 @@ func (self *Xgress) ForwardEndOfCircuit(sendF func(payload *Payload) bool) { // for now always send end of circuit. too many is better than not enough if self.flags.CompareAndSet(endOfCircuitSentFlag, false, true) { sendF(self.GetEndCircuit()) - self.flags.Set(endOfCircuitSentFlag, true) } } @@ -278,40 +298,50 @@ func (self *Xgress) IsEndOfCircuitSent() bool { return self.flags.IsSet(endOfCircuitSentFlag) } -func (self *Xgress) CloseTimeout(duration time.Duration) { - if self.payloadBuffer.CloseWhenEmpty() { // If we clear the send buffer, close sooner - time.AfterFunc(duration, self.Close) - } +func (self *Xgress) CloseRxTimeout() { + self.sendEOF() + self.payloadBuffer.CloseWhenEmpty() } func (self *Xgress) Unrouted() { - // When we're unrouted, if end of circuit hasn't already arrived, give incoming/queued data - // a chance to outflow before closing - if !self.flags.IsSet(closedFlag) { - self.payloadBuffer.Close() - time.AfterFunc(self.Options.MaxCloseWait, self.Close) + // if we're unrouted no more data is inbound + self.CloseXgToClient() + + // When we're unrouted, if 'end of circuit' hasn't already arrived, give incoming/queued data + // a chance to outflow before closing. We're unrouted so no point in sending EOF + self.payloadBuffer.CloseWhenEmpty() +} + +func (self *Xgress) CloseXgToClient() { + pfxlog.ContextLogger(self.Label()).Debug("close xg to client") + if self.flags.CompareAndSet(closedTxer, false, true) { + close(self.closeNotify) + } + + if self.payloadBuffer.IsClosed() { + self.Close() } } /* -Things which can trigger close - -1. Read fails -2. Write fails -3. End of Circuit received -4. Unroute received +Close should only be called once both sides of the circuit are complete. 
*/ func (self *Xgress) Close() { log := pfxlog.ContextLogger(self.Label()) if self.flags.CompareAndSet(closedFlag, false, true) { - log.Debug("closing xgress peer") + log.Debug("closing xgress") + + self.sendEndOfCircuit() + if err := self.peer.Close(); err != nil { log.WithError(err).Warn("error while closing xgress peer") } log.Debug("closing tx queue") - close(self.closeNotify) + if self.flags.CompareAndSet(closedTxer, false, true) { + close(self.closeNotify) + } self.payloadBuffer.Close() @@ -329,6 +359,19 @@ func (self *Xgress) Close() { } } +func (self *Xgress) PeerClosed() { + log := pfxlog.ContextLogger(self.Label()) + log.Debug("peer closed") + self.CloseXgToClient() + self.CloseRxTimeout() +} + +func (self *Xgress) closeIfRxAndTxDone() { + if self.payloadBuffer.IsClosed() && self.flags.IsSet(closedTxer) { + self.Close() + } +} + func (self *Xgress) CloseSendBuffer() { self.payloadBuffer.Close() } @@ -337,14 +380,24 @@ func (self *Xgress) Closed() bool { return self.flags.IsSet(closedFlag) } +func (self *Xgress) IsClosed() bool { + return self.flags.IsSet(closedFlag) +} + func (self *Xgress) SendPayload(payload *Payload, _ time.Duration, _ PayloadType) error { if self.Closed() { return nil } if payload.IsCircuitEndFlagSet() { - pfxlog.ContextLogger(self.Label()).Debug("received end of circuit Payload") + pfxlog.ContextLogger(self.Label()).Debug("received end of circuit payload") + } + + if payload.IsFlagWriteFailedSet() { + pfxlog.ContextLogger(self.Label()).Debug("received write failed payload") + self.payloadBuffer.Close() } + atomic.StoreInt64(&self.timeOfLastRxFromLink, time.Now().UnixMilli()) self.dataPlane.GetPayloadIngester().ingest(payload, self) @@ -372,7 +425,11 @@ func (self *Xgress) HandleControlReceive(controlType ControlType, headers channe func (self *Xgress) acceptPayload(payload *Payload) { if payload.IsCircuitStartFlagSet() && self.firstCircuitStartReceived() { - go self.rx() + pfxlog.ContextLogger(self.Label()).Debug("start received") + go self.payloadBuffer.run() + if !self.flags.IsSet(rxPushModeFlag) { + go self.rx() + } } if !self.Options.RandomDrops || rand.Int31n(self.Options.Drop1InN) != 1 { @@ -386,10 +443,11 @@ func (self *Xgress) tx() { log.Debug("started") defer log.Debug("exited") defer func() { - if self.IsEndOfCircuitReceived() { - self.Close() - } else { - self.flushSendThenClose() + if signalConn, ok := self.peer.(SignalConnection); ok { + signalConn.FlowFromFabricToXgressClosed() + } + if !self.IsEndOfCircuitReceived() { + self.sendWriteFailed() } }() @@ -409,8 +467,9 @@ func (self *Xgress) tx() { sendPayload := func(payload *Payload) bool { payloadLogger := log.WithFields(payload.GetLoggerFields()) - if payload.IsCircuitEndFlagSet() { + if payload.IsCircuitEndFlagSet() || payload.IsFlagEOFSet() { self.markCircuitEndReceived() + self.CloseXgToClient() payloadLogger.Debug("circuit end payload received, exiting") return false } @@ -508,17 +567,48 @@ func (self *Xgress) tx() { } } -func (self *Xgress) flushSendThenClose() { - self.CloseTimeout(self.Options.MaxCloseWait) - self.ForwardEndOfCircuit(func(payload *Payload) bool { - if self.payloadBuffer.closed.Load() { - // Avoid spurious 'failed to forward payload' error if the buffer is already closed - return false - } +func (self *Xgress) sendEOF() { + log := pfxlog.ContextLogger(self.Label()) + log.Debug("sendEOF") + + if self.payloadBuffer.closed.Load() { + // Avoid spurious 'failed to forward payload' error if the buffer is already closed + return + } + + payload := &Payload{ + CircuitId: 
self.circuitId, + Flags: SetOriginatorFlag(uint32(PayloadFlagEOF), self.originator), + Sequence: int32(self.nextReceiveSequence()), + Data: nil, + } - pfxlog.ContextLogger(self.Label()).Info("sending end of circuit payload") - return self.forwardPayload(payload) - }) + _ = self.forwardPayload(payload, context.Background()) +} + +func (self *Xgress) sendWriteFailed() { + log := pfxlog.ContextLogger(self.Label()) + log.Debug("sendWriteFailed") + + if self.payloadBuffer.closed.Load() { + return + } + + payload := &Payload{ + CircuitId: self.circuitId, + Flags: SetOriginatorFlag(uint32(PayloadFlagWriteFailed), self.originator), + Sequence: int32(self.nextReceiveSequence()), + Data: nil, + } + + log.Debug("sending end of circuit payload") + _ = self.forwardPayload(payload, context.Background()) +} + +func (self *Xgress) sendEndOfCircuit() { + log := pfxlog.ContextLogger(self.Label()) + log.Debug("sendEndOfCircuit") + self.dataPlane.ForwardPayload(self.GetEndCircuit(), self, context.Background()) } /** @@ -592,7 +682,7 @@ func (self *Xgress) rx() { } }() - defer self.flushSendThenClose() + defer self.CloseRxTimeout() for { buffer, headers, err := self.peer.ReadPayload() @@ -601,7 +691,10 @@ func (self *Xgress) rx() { // if we got an EOF, but also some data, ignore the EOF, next read we'll get 0, EOF if err != nil && (n == 0 || err != io.EOF) { - if err == io.EOF { + if err == io.EOF || errors.Is(err, ErrPeerClosed) { + if errors.Is(err, ErrPeerClosed) { // if the peer closed, we need to close the txer as well + self.CloseXgToClient() + } log.Debugf("EOF, exiting xgress.rx loop") } else { log.Warnf("read failed (%s)", err) @@ -614,117 +707,132 @@ func (self *Xgress) rx() { return } - if self.Options.Mtu == 0 { - if !self.sendUnchunkedBuffer(buffer, headers) { - return - } - continue + if err = self.Write(buffer, headers, nil); err != nil { + return } + } +} - first := true - chunked := false - for len(buffer) > 0 || (first && len(headers) > 0) { - seq := self.nextReceiveSequence() +func (self *Xgress) Write(buffer []byte, headers map[uint8][]byte, ctx context.Context) error { + log := pfxlog.ContextLogger(self.Label()) - chunk := make([]byte, self.Options.Mtu) + log.Debugf("payload read: %d bytes read", len(buffer)) + n := len(buffer) - flagsHeader := VersionMask & (PayloadProtocolV1 << PayloadProtocolOffset) - var sizesHeader byte - if self.originator == Terminator { - flagsHeader |= TerminatorFlagMask - } + if self.Closed() { + return ErrWriteClosed + } - written := 2 - rest := chunk[2:] - includeRtt := seq%5 == 0 - if includeRtt { - flagsHeader |= RttFlagMask - written += 2 - rest = rest[2:] - } + if self.Options.Mtu == 0 { + return self.sendUnchunkedBuffer(buffer, headers, ctx) + } - size := copy(rest, self.circuitId) - sizesHeader |= CircuitIdSizeMask & uint8(size) - written += size - rest = rest[size:] - size = binary.PutUvarint(rest, seq) - rest = rest[size:] - written += size + first := true + chunked := false + var err error - if first && len(headers) > 0 { - flagsHeader |= HeadersFlagMask - size, err = writeU8ToBytesMap(headers, rest) - if err != nil { - log.WithError(err).Error("payload encoding error, closing") - return - } - rest = rest[size:] - written += size - } + for len(buffer) > 0 || (first && len(headers) > 0) { + seq := self.nextReceiveSequence() - data := rest - dataLen := 0 - if first && len(rest) < len(buffer) { - chunked = true - size = binary.PutUvarint(rest, uint64(n)) - dataLen += size - written += size - rest = rest[size:] - } + chunk := make([]byte, self.Options.Mtu) 
- if chunked { - flagsHeader |= ChunkFlagMask - } + flagsHeader := VersionMask & (PayloadProtocolV1 << PayloadProtocolOffset) + var sizesHeader byte + if self.originator == Terminator { + flagsHeader |= TerminatorFlagMask + } + + written := 2 + rest := chunk[2:] + includeRtt := seq%5 == 0 + if includeRtt { + flagsHeader |= RttFlagMask + written += 2 + rest = rest[2:] + } - size = copy(rest, buffer) + size := copy(rest, self.circuitId) + sizesHeader |= CircuitIdSizeMask & uint8(size) + written += size + rest = rest[size:] + size = binary.PutUvarint(rest, seq) + rest = rest[size:] + written += size + + if first && len(headers) > 0 { + flagsHeader |= HeadersFlagMask + size, err = writeU8ToBytesMap(headers, rest) + if err != nil { + log.WithError(err).Error("payload encoding error, closing") + return err + } + rest = rest[size:] written += size + } + + data := rest + dataLen := 0 + if first && len(rest) < len(buffer) { + chunked = true + size = binary.PutUvarint(rest, uint64(n)) dataLen += size + written += size + rest = rest[size:] + } - buffer = buffer[size:] + if chunked { + flagsHeader |= ChunkFlagMask + } - // check if there's room for a heartbeat - if written+8 <= len(chunk) { - flagsHeader |= HeartbeatFlagMask - written += 8 - } + size = copy(rest, buffer) + written += size + dataLen += size - chunk[0] = flagsHeader - chunk[1] = sizesHeader + buffer = buffer[size:] - payload := &Payload{ - CircuitId: self.circuitId, - Flags: SetOriginatorFlag(0, self.originator), - Sequence: int32(seq), - Data: data[:dataLen], - raw: chunk[:written], - } + // check if there's room for a heartbeat + if written+8 <= len(chunk) { + flagsHeader |= HeartbeatFlagMask + written += 8 + } - if chunked { - payload.Flags = setPayloadFlag(payload.Flags, PayloadFlagChunk) - } + chunk[0] = flagsHeader + chunk[1] = sizesHeader - if first { - payload.Headers = headers - } + payload := &Payload{ + CircuitId: self.circuitId, + Flags: SetOriginatorFlag(0, self.originator), + Sequence: int32(seq), + Data: data[:dataLen], + raw: chunk[:written], + } - log.Debugf("sending payload chunk. seq: %d, first: %v, chunk size: %d, payload size: %d, remainder: %d", payload.Sequence, first, len(payload.Data), n, len(buffer)) - first = false + if chunked { + payload.Flags = setPayloadFlag(payload.Flags, PayloadFlagChunk) + } - // if the payload buffer is closed, we can't forward any more data, so might as well exit the rx loop - // The txer will still have a chance to flush any already received data - if !self.forwardPayload(payload) { - return - } + if first { + payload.Headers = headers + } - payloadLogger := log.WithFields(payload.GetLoggerFields()) - payloadLogger.Debugf("forwarded [%s]", info.ByteCount(int64(n))) + log.Debugf("sending payload chunk. 
seq: %d, first: %v, chunk size: %d, payload size: %d, remainder: %d", payload.Sequence, first, len(payload.Data), n, len(buffer)) + first = false + + // if the payload buffer is closed, we can't forward any more data, so might as well exit the rx loop + // The txer will still have a chance to flush any already received data + if err = self.forwardPayload(payload, ctx); err != nil { + return err } - logrus.Debugf("received payload for [%d] bytes", n) + payloadLogger := log.WithFields(payload.GetLoggerFields()) + payloadLogger.Debugf("forwarded [%s]", info.ByteCount(int64(n))) } + + logrus.Debugf("received payload for [%d] bytes", n) + return nil } -func (self *Xgress) sendUnchunkedBuffer(buf []byte, headers map[uint8][]byte) bool { +func (self *Xgress) sendUnchunkedBuffer(buf []byte, headers map[uint8][]byte, ctx context.Context) error { log := pfxlog.ContextLogger(self.Label()) payload := &Payload{ @@ -739,30 +847,39 @@ func (self *Xgress) sendUnchunkedBuffer(buf []byte, headers map[uint8][]byte) bo // if the payload buffer is closed, we can't forward any more data, so might as well exit the rx loop // The txer will still have a chance to flush any already received data - if !self.forwardPayload(payload) { - return false + if err := self.forwardPayload(payload, ctx); err != nil { + return err } payloadLogger := log.WithFields(payload.GetLoggerFields()) payloadLogger.Debugf("forwarded [%s]", info.ByteCount(int64(len(buf)))) - return true + return nil } -func (self *Xgress) forwardPayload(payload *Payload) bool { - sendCallback, err := self.payloadBuffer.BufferPayload(payload) +func (self *Xgress) forwardPayload(payload *Payload, ctx context.Context) error { + var sendCallback func() + var err error + + if ctx == nil { + sendCallback, err = self.payloadBuffer.BufferPayload(payload) + } else { + sendCallback, err = self.payloadBuffer.BufferPayloadWithDeadline(payload, ctx) + } if err != nil { - pfxlog.ContextLogger(self.Label()).WithError(err).Error("failure to buffer payload") - return false + if !payload.IsCircuitEndFlagSet() && !payload.IsFlagEOFSet() { + pfxlog.ContextLogger(self.Label()).WithError(err).Error("failure to buffer payload") + } + return err } for _, peekHandler := range self.peekHandlers { peekHandler.Rx(self, payload) } - self.dataPlane.ForwardPayload(payload, self) + self.dataPlane.ForwardPayload(payload, self, ctx) sendCallback() - return true + return nil } func (self *Xgress) nextReceiveSequence() uint64 { @@ -1015,3 +1132,90 @@ func readU8ToBytesMap(buf []byte) (map[uint8][]byte, []byte, error) { return result, buf, nil } + +func NewWriteAdapter(x *Xgress) *WriteAdapter { + result := &WriteAdapter{ + x: x, + } + result.doneNotify.Store(make(chan struct{})) + return result +} + +type WriteAdapter struct { + x *Xgress + deadline concurrenz.AtomicValue[time.Time] + doneNotify concurrenz.AtomicValue[chan struct{}] + doneNotifyClosed bool + lock sync.Mutex +} + +func (self *WriteAdapter) Deadline() (deadline time.Time, ok bool) { + deadline = self.deadline.Load() + return deadline, !deadline.IsZero() +} + +func (self *WriteAdapter) Done() <-chan struct{} { + return self.doneNotify.Load() +} + +func (self *WriteAdapter) Err() error { + return nil +} + +func (self *WriteAdapter) Value(any) any { + return nil +} + +func (self *WriteAdapter) SetWriteDeadline(t time.Time) error { + self.lock.Lock() + defer self.lock.Unlock() + + self.deadline.Store(t) + if t.IsZero() { + if self.doneNotifyClosed { + self.doneNotify.Store(make(chan struct{})) + self.doneNotifyClosed = false + } + 
return nil + } + d := time.Until(t) + if d > 0 { + if self.doneNotifyClosed { + self.doneNotify.Store(make(chan struct{})) + self.doneNotifyClosed = false + } + + time.AfterFunc(d, func() { + self.lock.Lock() + defer self.lock.Unlock() + + if t.Equal(self.deadline.Load()) { + if !self.doneNotifyClosed { + close(self.doneNotify.Load()) + self.doneNotifyClosed = true + } + } + }) + } else { + if !self.doneNotifyClosed { + close(self.doneNotify.Load()) + self.doneNotifyClosed = true + } + } + + return nil +} + +func (self *WriteAdapter) Write(b []byte) (n int, err error) { + if err = self.x.Write(b, nil, self); err != nil { + return 0, err + } + return len(b), nil +} + +func (self *WriteAdapter) WriteToXgress(b []byte, header map[uint8][]byte) (n int, err error) { + if err = self.x.Write(b, header, self); err != nil { + return 0, err + } + return len(b), nil +} diff --git a/ziti/edge/msg_mux.go b/ziti/edge/msg_mux.go index 65535e7f..38f26013 100644 --- a/ziti/edge/msg_mux.go +++ b/ziti/edge/msg_mux.go @@ -120,7 +120,7 @@ func (mux *MsgMuxImpl) handlePayloadWithNoSink(msg *channel.Message, ch channel. connId, _ := msg.GetUint32Header(ConnIdHeader) payload, err := xgress.UnmarshallPayload(msg) if err == nil { - if payload.IsCircuitEndFlagSet() && len(payload.Data) == 0 { + if (payload.IsCircuitEndFlagSet() || payload.IsFlagEOFSet()) && len(payload.Data) == 0 { ack := xgress.NewAcknowledgement(payload.CircuitId, payload.GetOriginator().Invert()) ackMsg := ack.Marshall() ackMsg.PutUint32Header(ConnIdHeader, connId) diff --git a/ziti/edge/network/conn.go b/ziti/edge/network/conn.go index 1817049b..e782ff9d 100644 --- a/ziti/edge/network/conn.go +++ b/ziti/edge/network/conn.go @@ -23,6 +23,7 @@ import ( "github.com/openziti/sdk-golang/xgress" "io" "net" + "strconv" "strings" "sync" "sync/atomic" @@ -44,13 +45,6 @@ import ( var unsupportedCrypto = errors.New("unsupported crypto") -type ConnType byte - -const ( - ConnTypeDial ConnType = 1 - ConnTypeBind ConnType = 2 -) - var _ edge.Conn = &edgeConn{} type edgeConn struct { @@ -84,6 +78,9 @@ type edgeConn struct { func (conn *edgeConn) Write(data []byte) (int, error) { if conn.sentFIN.Load() { + if conn.IsClosed() { + return 0, errors.New("connection closed") + } return 0, errors.New("calling Write() after CloseWrite()") } @@ -114,6 +111,10 @@ func (conn *edgeConn) CloseWrite() error { return err } + if conn.xgCircuit != nil { + conn.xgCircuit.xg.CloseRxTimeout() + } + return nil } @@ -177,6 +178,10 @@ func (conn *edgeConn) Accept(msg *channel.Message) { if conn.IsClosed() { return } + // routing is not accepting more data, so we need to close the send buffer + if conn.xgCircuit != nil { + conn.xgCircuit.xg.CloseSendBuffer() + } conn.sentFIN.Store(true) // if we're not closing until all reads are done, at least prevent more writes case edge.ContentTypeInspectRequest: @@ -218,18 +223,21 @@ func (conn *edgeConn) Accept(msg *channel.Message) { func (conn *edgeConn) HandleXgPayload(msg *channel.Message) { adapter := conn.xgCircuit + if adapter == nil { - // TODO: handle + pfxlog.Logger().WithField("circuitId", conn.circuitId).Error("can't accept payload, xgress adapter not present") return } payload, err := xgress.UnmarshallPayload(msg) if err != nil { + pfxlog.Logger().WithField("circuitId", conn.circuitId).WithError(err).Error("error unmarshalling payload") adapter.xg.Close() return } if err = adapter.xg.SendPayload(payload, 0, 0); err != nil { + pfxlog.Logger().WithField("circuitId", conn.circuitId).WithError(err).Error("error accepting payload") 
adapter.xg.Close() } } @@ -237,17 +245,19 @@ func (conn *edgeConn) HandleXgPayload(msg *channel.Message) { func (conn *edgeConn) HandleXgAcknowledgement(msg *channel.Message) { adapter := conn.xgCircuit if adapter == nil { - // TODO: handle + pfxlog.Logger().WithField("circuitId", conn.circuitId).Error("can't accept ack, xgress adapter not present") return } ack, err := xgress.UnmarshallAcknowledgement(msg) if err != nil { + pfxlog.Logger().WithField("circuitId", conn.circuitId).WithError(err).Error("error unmarshalling acknowledgement") adapter.xg.Close() return } if err = adapter.xg.SendAcknowledgement(ack); err != nil { + pfxlog.Logger().WithField("circuitId", conn.circuitId).WithError(err).Error("error accepting acknowledgement") adapter.xg.Close() } // adapter.env.GetAckIngester().Ingest(msg, adapter.xg) @@ -346,6 +356,13 @@ func (conn *edgeConn) SetDeadline(t time.Time) error { return conn.SetWriteDeadline(t) } +func (conn *edgeConn) SetWriteDeadline(t time.Time) error { + if conn.xgCircuit != nil { + return conn.xgCircuit.writeAdapter.SetWriteDeadline(t) + } + return conn.MsgChannel.SetWriteDeadline(t) +} + func (conn *edgeConn) SetReadDeadline(t time.Time) error { conn.readQ.SetReadDeadline(t) return nil @@ -370,13 +387,12 @@ func (conn *edgeConn) GetStickinessToken() []byte { } func (conn *edgeConn) HandleClose(channel.Channel) { - logger := pfxlog.Logger().WithField("connId", conn.Id()).WithField("marker", conn.marker) + logger := pfxlog.Logger().WithField("connId", conn.Id()).WithField("marker", conn.marker).WithField("circuitId", conn.circuitId) defer logger.Debug("received HandleClose from underlying channel, marking conn closed") - if conn.closed.CompareAndSwap(false, true) { - close(conn.closeNotify) + conn.close(true) + if conn.xgCircuit != nil { + conn.xgCircuit.xg.CloseSendBuffer() } - conn.sentFIN.Store(true) - conn.readFIN.Store(true) } func (conn *edgeConn) Connect(session *rest_model.SessionDetail, options *edge.DialOptions, envF func() xgress.Env) (edge.Conn, error) { @@ -459,19 +475,18 @@ func (conn *edgeConn) setupFlowControl(msg *channel.Message, originator xgress.O } xgAdapter := &XgAdapter{ - conn: conn, - readC: make(chan []byte), - closeNotify: conn.closeNotify, - env: envF(), + conn: conn, + readC: make(chan []byte), + env: envF(), } conn.xgCircuit = xgAdapter xg := xgress.NewXgress(conn.circuitId, ctrlId, xgress.Address(addr), xgAdapter, originator, xgress.DefaultOptions(), nil) xgAdapter.xg = xg - conn.dataSink = xgAdapter + xgAdapter.writeAdapter = xg.NewWriteAdapter() + xgAdapter.xg.AddCloseHandler(xgAdapter) + conn.dataSink = xgAdapter.writeAdapter xg.SetDataPlaneAdapter(xgAdapter) - xg.AddCloseHandler(xgAdapter) - xg.Start() } else { conn.dataSink = &conn.MsgChannel @@ -532,8 +547,12 @@ func (conn *edgeConn) establishServerCrypto(keypair *kx.KeyPair, peerKey []byte, } func (conn *edgeConn) Read(p []byte) (int, error) { - log := pfxlog.Logger().WithField("connId", conn.Id()).WithField("marker", conn.marker) + log := pfxlog.Logger().WithField("connId", conn.Id()). + WithField("marker", conn.marker). 
+ WithField("circuitId", conn.circuitId) + if conn.closed.Load() { + log.Trace("edgeConn closed, returning EOF") return 0, io.EOF } @@ -553,21 +572,23 @@ func (conn *edgeConn) Read(p []byte) (int, error) { for { if conn.readFIN.Load() { + log.Tracef("readFIN true, returning EOF") return 0, io.EOF } msg, err := conn.readQ.GetNext() if errors.Is(err, ErrClosed) { - log.Debug("sequencer closed, closing connection") - conn.closed.Store(true) + log.Debug("sequencer closed, marking readFIN") + conn.readFIN.Store(true) return 0, io.EOF } else if err != nil { - log.Debugf("unexpected sequencer err (%v)", err) + log.WithError(err).Debug("unexpected sequencer err") return 0, err } flags, _ := msg.GetUint32Header(edge.FlagsHeader) if flags&edge.FIN != 0 { + log.Trace("got fin msg, marking readFIN true") conn.readFIN.Store(true) } conn.flags = conn.flags | (flags & (edge.STREAM | edge.MULTIPART)) @@ -575,8 +596,18 @@ func (conn *edgeConn) Read(p []byte) (int, error) { switch msg.ContentType { case edge.ContentTypeStateClosed: - log.Debug("received ConnState_CLOSED message, closing connection") - conn.close(true) + if conn.xgCircuit != nil { + conn.readFIN.Store(true) + if conn.sentFIN.Load() { + log.Debug("received ConnState_CLOSED message, fin sent, closing connection") + conn.close(true) + } else { + log.Debug("received ConnState_CLOSED message, fin not yet sent") + } + } else { + log.Debug("received ConnState_CLOSED message, closing connection") + conn.close(true) + } continue case edge.ContentTypeData: @@ -643,6 +674,7 @@ func (conn *edgeConn) Read(p []byte) (int, error) { } func (conn *edgeConn) Close() error { + pfxlog.Logger().WithField("connId", strconv.Itoa(int(conn.Id()))).WithField("circuitId", conn.circuitId).Debug("closing edge conn") conn.close(false) return nil } @@ -659,20 +691,26 @@ func (conn *edgeConn) close(closedByRemote bool) { conn.readFIN.Store(true) conn.sentFIN.Store(true) - log := pfxlog.Logger().WithField("connId", conn.Id()).WithField("marker", conn.marker) + log := pfxlog.Logger().WithField("connId", int(conn.Id())).WithField("marker", conn.marker).WithField("circuitId", conn.circuitId) + log.Debug("close: begin") defer log.Debug("close: end") - if !closedByRemote { - msg := edge.NewStateClosedMsg(conn.Id(), "") - if err := conn.SendState(msg); err != nil { - log.WithError(err).Error("failed to send close message") + if conn.xgCircuit == nil { + if !closedByRemote { + msg := edge.NewStateClosedMsg(conn.Id(), "") + if err := conn.SendState(msg); err != nil { + log.WithError(err).Error("failed to send close message") + } } - } - // if we're using xgress, wait to remove the conn from the mux until the xgress closes, otherwise it becomes unroutable. - if conn.xgCircuit == nil { conn.msgMux.RemoveMsgSink(conn) // if we switch back to ChMsgMux will need to be done async again, otherwise we may deadlock + } else { + // cancel any pending writes + _ = conn.xgCircuit.writeAdapter.SetWriteDeadline(time.Now()) + + // if we're using xgress, wait to remove the connection from the mux until the xgress closes, otherwise it becomes unroutable. 
+ conn.xgCircuit.xg.PeerClosed() } } @@ -768,7 +806,7 @@ func (self *newConnHandler) dialSucceeded() error { if !self.routerProvidedConnId { startMsg, err := reply.WithPriority(channel.Highest).WithTimeout(5 * time.Second).SendForReply(self.conn.GetControlSender()) if err != nil { - logger.WithError(err).Error("Failed to send reply to dial request") + logger.WithError(err).Error("failed to send reply to dial request") return err } @@ -777,7 +815,7 @@ func (self *newConnHandler) dialSucceeded() error { return errors.Errorf("failed to receive start after dial. got %v", startMsg) } } else if err := reply.WithPriority(channel.Highest).WithTimeout(time.Second * 5).SendAndWaitForWire(self.conn.GetControlSender()); err != nil { - logger.WithError(err).Error("Failed to send reply to dial request") + logger.WithError(err).Error("failed to send reply to dial request") return err } diff --git a/ziti/edge/network/hosting_conn.go b/ziti/edge/network/hosting_conn.go index 90ff1302..a8b80515 100644 --- a/ziti/edge/network/hosting_conn.go +++ b/ziti/edge/network/hosting_conn.go @@ -57,7 +57,10 @@ func (conn *edgeHostConn) Accept(msg *channel.Message) { switch msg.ContentType { case edge.ContentTypeDial: newConnId, _ := msg.GetUint32Header(edge.RouterProvidedConnId) - logrus.WithFields(edge.GetLoggerFields(msg)).WithField("newConnId", newConnId).Debug("received dial request") + circuitId, _ := msg.GetStringHeader(edge.CircuitIdHeader) + logrus.WithFields(edge.GetLoggerFields(msg)). + WithField("circuitId", circuitId). + WithField("newConnId", newConnId).Debug("received dial request") go conn.newChildConnection(msg) case edge.ContentTypeStateClosed: conn.close(true) diff --git a/ziti/edge/network/seq.go b/ziti/edge/network/seq.go index a7ea3d28..3f8f6acb 100644 --- a/ziti/edge/network/seq.go +++ b/ziti/edge/network/seq.go @@ -25,25 +25,34 @@ func (r ReadTimout) Temporary() bool { func NewNoopSequencer[T any](closeNotify <-chan struct{}, channelDepth int) *noopSeq[T] { return &noopSeq[T]{ - closeNotify: closeNotify, - ch: make(chan T, channelDepth), - deadlineNotify: make(chan struct{}), + externalCloseNotify: closeNotify, + ch: make(chan T, channelDepth), + deadlineNotify: make(chan struct{}), + closeNotify: make(chan struct{}), } } type noopSeq[T any] struct { - ch chan T - closeNotify <-chan struct{} - deadlineNotify chan struct{} - deadline concurrenz.AtomicValue[time.Time] - readInProgress atomic.Bool + ch chan T + externalCloseNotify <-chan struct{} + deadlineNotify chan struct{} + closeNotify chan struct{} + closed atomic.Bool + deadline concurrenz.AtomicValue[time.Time] + readInProgress atomic.Bool +} + +func (self *noopSeq[T]) Close() { + if self.closed.CompareAndSwap(false, true) { + close(self.closeNotify) + } } func (seq *noopSeq[T]) PutSequenced(event T) error { select { case seq.ch <- event: return nil - case <-seq.closeNotify: + case <-seq.externalCloseNotify: return ErrClosed } } @@ -81,6 +90,14 @@ func (seq *noopSeq[T]) GetNext() (T, error) { select { case val = <-seq.ch: return val, nil + case <-seq.externalCloseNotify: + // If we're closed, return any buffered values, otherwise return nil + select { + case val = <-seq.ch: + return val, nil + default: + return val, ErrClosed + } case <-seq.closeNotify: // If we're closed, return any buffered values, otherwise return nil select { diff --git a/ziti/edge/network/seq_test.go b/ziti/edge/network/seq_test.go index 35705bee..b7e5a8af 100644 --- a/ziti/edge/network/seq_test.go +++ b/ziti/edge/network/seq_test.go @@ -1,9 +1,11 @@ package network 
 import (
+	"fmt"
 	"github.com/openziti/channel/v4"
 	"github.com/openziti/sdk-golang/ziti/edge"
 	"github.com/stretchr/testify/require"
+	"math"
 	"testing"
 	"time"
 )
@@ -89,3 +91,8 @@ func Test_SeqReadWithInterrupt(t *testing.T) {
 	req.ErrorIs(err, &ReadTimout{})
 	req.True(time.Since(first) < time.Millisecond)
 }
+
+func Test_GetMaxMsgMux(t *testing.T) {
+	maxId := (math.MaxUint32 / 2) - 1
+	fmt.Printf("max id: %d\n", maxId)
+}
diff --git a/ziti/edge/network/xg_adapter.go b/ziti/edge/network/xg_adapter.go
index 0774ac85..6180fdbb 100644
--- a/ziti/edge/network/xg_adapter.go
+++ b/ziti/edge/network/xg_adapter.go
@@ -1,6 +1,8 @@
 package network
 
 import (
+	"context"
+	"errors"
 	"fmt"
 	"github.com/michaelquigley/pfxlog"
 	"github.com/openziti/channel/v4"
@@ -8,45 +10,48 @@ import (
 	"github.com/openziti/sdk-golang/xgress"
 	"github.com/openziti/sdk-golang/ziti/edge"
 	"github.com/sirupsen/logrus"
-	"io"
 	"time"
 )
 
 type XgAdapter struct {
-	conn        *edgeConn
-	readC       chan []byte
-	closeNotify <-chan struct{}
-	env         xgress.Env
-	xg          *xgress.Xgress
+	conn         *edgeConn
+	readC        chan []byte
+	env          xgress.Env
+	xg           *xgress.Xgress
+	writeAdapter *xgress.WriteAdapter
}

 func (self *XgAdapter) HandleXgressClose(x *xgress.Xgress) {
-	self.xg.ForwardEndOfCircuit(func(payload *xgress.Payload) bool {
-		self.ForwardPayload(payload, x)
-		return true
-	})
-	self.conn.close(true)
-
-	// see note in close
-	self.conn.msgMux.RemoveMsgSink(self.conn)
-
 	xgCloseMsg := channel.NewMessage(edge.ContentTypeXgClose, []byte(self.xg.CircuitId()))
 	if err := xgCloseMsg.WithTimeout(5 * time.Second).Send(self.conn.GetControlSender()); err != nil {
 		pfxlog.Logger().WithError(err).Error("failed to send xg close message")
 	}
+
+	// see note in close
+	self.conn.msgMux.RemoveMsgSink(self.conn)
 }
 
-func (self *XgAdapter) ForwardPayload(payload *xgress.Payload, x *xgress.Xgress) {
+func (self *XgAdapter) ForwardPayload(payload *xgress.Payload, _ *xgress.Xgress, ctx context.Context) {
 	msg := payload.Marshall()
 	msg.PutUint32Header(edge.ConnIdHeader, self.conn.Id())
-	if err := self.conn.MsgChannel.GetDefaultSender().Send(msg); err != nil {
-		pfxlog.Logger().WithError(err).Error("failed to send payload")
+
+	if err := msg.WithContext(ctx).SendAndWaitForWire(self.conn.GetDefaultSender()); err != nil {
+		pfxlog.Logger().WithField("circuitId", payload.CircuitId).WithError(err).Error("failed to send payload")
 	}
 }
 
 func (self *XgAdapter) RetransmitPayload(srcAddr xgress.Address, payload *xgress.Payload) error {
 	msg := payload.Marshall()
-	return self.conn.MsgChannel.GetDefaultSender().Send(msg)
+	sent, err := self.conn.MsgChannel.GetDefaultSender().TrySend(msg)
+	if err != nil {
+		return err
+	}
+
+	if !sent {
+		pfxlog.Logger().WithField("circuitId", payload.CircuitId).Error("payload dropped")
+	}
+
+	return nil
 }
 
 func (self *XgAdapter) ForwardControlMessage(control *xgress.Control, x *xgress.Xgress) {
@@ -76,33 +81,15 @@ func (self *XgAdapter) GetMetrics() xgress.Metrics {
 }
 
 func (self *XgAdapter) Close() error {
-	return self.conn.Close()
+	return nil
 }
 
 func (self *XgAdapter) LogContext() string {
 	return fmt.Sprintf("xg/%s", self.conn.GetCircuitId())
 }
 
-func (self *XgAdapter) Write(bytes []byte) (int, error) {
-	select {
-	case self.readC <- bytes:
-		return len(bytes), nil
-	case <-self.closeNotify:
-		return 0, io.EOF
-	}
-}
-
 func (self *XgAdapter) ReadPayload() ([]byte, map[uint8][]byte, error) {
-	// log := pfxlog.ContextLogger(self.LogContext()).WithField("connId", self.conn.Id())
-
-	var data []byte
-	select {
-	case data = <-self.readC:
-	case
<-self.closeNotify: - return nil, nil, io.EOF - } - - return data, nil, nil + return nil, nil, errors.New("should never be called") } func (self *XgAdapter) WritePayload(bytes []byte, headers map[uint8][]byte) (int, error) { @@ -131,6 +118,7 @@ func (self *XgAdapter) WritePayload(bytes []byte, headers map[uint8][]byte) (int if err := self.conn.readQ.PutSequenced(msg); err != nil { logrus.WithFields(edge.GetLoggerFields(msg)).WithError(err). + WithField("circuitId", self.conn.circuitId). Error("error pushing edge message to sequencer") return 0, err } @@ -139,6 +127,12 @@ func (self *XgAdapter) WritePayload(bytes []byte, headers map[uint8][]byte) (int return len(msg.Body), nil } +func (self *XgAdapter) FlowFromFabricToXgressClosed() { + pfxlog.Logger().WithField("circuitId", self.conn.circuitId). + Debug("fabric to sdk flow complete") + self.conn.readQ.Close() +} + func (self *XgAdapter) HandleControlMsg(controlType xgress.ControlType, headers channel.Headers, responder xgress.ControlReceiver) error { //TODO implement me panic("implement me") diff --git a/ziti/sdkinfo/build_info.go b/ziti/sdkinfo/build_info.go index 993bb04b..f6d099ed 100644 --- a/ziti/sdkinfo/build_info.go +++ b/ziti/sdkinfo/build_info.go @@ -20,5 +20,5 @@ package sdkinfo const ( - Version = "v1.1.3" + Version = "v1.2.0" ) diff --git a/ziti/ziti.go b/ziti/ziti.go index e6027776..eecb13cb 100644 --- a/ziti/ziti.go +++ b/ziti/ziti.go @@ -990,7 +990,7 @@ func (context *ContextImpl) RefreshApiSessionWithBackoff() error { logrus.Info("previous apiSession expired") return backoff.Permanent(err) } - logrus.WithError(err).Info("unable to refresh apiSession, will retry") + logrus.WithError(err).Infof("unable to refresh apiSession, error type %T, will retry", err) return err }
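
The WriteAdapter introduced in xgress.go doubles as a context.Context: SetWriteDeadline arms a timer that closes Done(), and a Write blocked in BufferPayloadWithDeadline observes that channel and surfaces os.ErrDeadlineExceeded. A minimal caller-side sketch of that flow follows; it is illustrative only, assuming an *xgress.Xgress already wired to a DataPlaneAdapter, with NewWriteAdapter called before Start so the xgress runs in push mode rather than starting its rx() loop.

package example

import (
	"errors"
	"os"
	"time"

	"github.com/openziti/sdk-golang/xgress"
)

// writeWithDeadline sketches a deadline-bounded write through the new
// WriteAdapter. The deadline closes the adapter's Done() channel, which
// BufferPayloadWithDeadline treats as cancellation, returning
// os.ErrDeadlineExceeded to the caller.
func writeWithDeadline(x *xgress.Xgress, data []byte) (int, error) {
	wa := x.NewWriteAdapter() // called before x.Start(): enables push mode

	if err := wa.SetWriteDeadline(time.Now().Add(2 * time.Second)); err != nil {
		return 0, err
	}

	n, err := wa.Write(data)
	if errors.Is(err, os.ErrDeadlineExceeded) {
		// no send-window space opened up before the deadline fired
	} else if errors.Is(err, xgress.ErrWriteClosed) {
		// the send side of the circuit is closed; no more writes possible
	}
	return n, err
}

Clearing the deadline with SetWriteDeadline(time.Time{}) re-arms the Done channel, which is what write_timeout_test.go exercises above.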
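At the core of the rework, receive and transmit teardown are tracked separately: the tx side closes via the closedTxer flag and closeNotify, the rx side via the LinkSendBuffer, and closeIfRxAndTxDone performs the full close only once both have finished. Reduced to a standalone sketch (names and structure here are illustrative, not the SDK's):

package example

import "sync/atomic"

// halfCloser distills the patch's close choreography: each side closes
// exactly once, and full teardown runs only after both sides are done.
type halfCloser struct {
	txClosed    atomic.Bool   // mirrors the closedTxer flag
	rxClosed    atomic.Bool   // mirrors LinkSendBuffer.closed
	fullyClosed atomic.Bool   // mirrors the closedFlag
	closeNotify chan struct{} // stops the tx loop, like Xgress.closeNotify
}

func (h *halfCloser) closeTx() {
	if h.txClosed.CompareAndSwap(false, true) {
		close(h.closeNotify) // CompareAndSwap guarantees a single close
	}
	h.closeIfDone()
}

func (h *halfCloser) closeRx() {
	h.rxClosed.Store(true) // in the patch: LinkSendBuffer.Close
	h.closeIfDone()
}

func (h *halfCloser) closeIfDone() {
	if h.txClosed.Load() && h.rxClosed.Load() &&
		h.fullyClosed.CompareAndSwap(false, true) {
		// full teardown: close the peer connection, remove from the mux, etc.
	}
}

In the patch, an incoming PayloadFlagEOF or ErrPeerClosed drives the tx-side close, while a local read EOF or write failure drains and closes the send buffer, so either direction can shut down without tearing down the other.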