diff --git a/xgress/circuit_inspections.go b/xgress/circuit_inspections.go new file mode 100644 index 00000000..09b99c3f --- /dev/null +++ b/xgress/circuit_inspections.go @@ -0,0 +1,94 @@ +/* + Copyright NetFoundry Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package xgress + +type CircuitInspectDetail struct { + CircuitId string `json:"circuitId"` + Forwards map[string]string `json:"forwards"` + XgressDetails map[string]*InspectDetail `json:"xgressDetails"` + LinkDetails map[string]*LinkInspectDetail `json:"linkDetails"` + includeGoroutines bool +} + +func (self *CircuitInspectDetail) SetIncludeGoroutines(includeGoroutines bool) { + self.includeGoroutines = includeGoroutines +} + +func (self *CircuitInspectDetail) IncludeGoroutines() bool { + return self.includeGoroutines +} + +func (self *CircuitInspectDetail) AddXgressDetail(xgressDetail *InspectDetail) { + self.XgressDetails[xgressDetail.Address] = xgressDetail +} + +func (self *CircuitInspectDetail) AddLinkDetail(linkDetail *LinkInspectDetail) { + self.LinkDetails[linkDetail.Id] = linkDetail +} + +type InspectDetail struct { + Address string `json:"address"` + Originator string `json:"originator"` + TimeSinceLastLinkRx string `json:"timeSinceLastLinkRx"` + SendBufferDetail *SendBufferDetail `json:"sendBufferDetail"` + RecvBufferDetail *RecvBufferDetail `json:"recvBufferDetail"` + XgressPointer string `json:"xgressPointer"` + LinkSendBufferPointer string `json:"linkSendBufferPointer"` + Goroutines []string `json:"goroutines"` + Sequence uint64 `json:"sequence"` + Flags string `json:"flags"` +} + +type SendBufferDetail struct { + WindowSize uint32 `json:"windowSize"` + LinkSendBufferSize uint32 `json:"linkSendBufferSize"` + LinkRecvBufferSize uint32 `json:"linkRecvBufferSize"` + Accumulator uint32 `json:"accumulator"` + SuccessfulAcks uint32 `json:"successfulAcks"` + DuplicateAcks uint32 `json:"duplicateAcks"` + Retransmits uint32 `json:"retransmits"` + Closed bool `json:"closed"` + BlockedByLocalWindow bool `json:"blockedByLocalWindow"` + BlockedByRemoteWindow bool `json:"blockedByRemoteWindow"` + RetxScale float64 `json:"retxScale"` + RetxThreshold uint32 `json:"retxThreshold"` + TimeSinceLastRetx string `json:"timeSinceLastRetx"` + CloseWhenEmpty bool `json:"closeWhenEmpty"` + AcquiredSafely bool `json:"acquiredSafely"` +} + +type RecvBufferDetail struct { + Size uint32 `json:"size"` + PayloadCount uint32 `json:"payloadCount"` + LastSizeSent uint32 `json:"lastSizeSent"` + Sequence int32 `json:"sequence"` + MaxSequence int32 `json:"maxSequence"` + NextPayload string `json:"nextPayload"` + AcquiredSafely bool `json:"acquiredSafely"` +} + +type LinkInspectDetail struct { + Id string `json:"id"` + Iteration uint32 `json:"iteration"` + Key string `json:"key"` + Split bool `json:"split"` + Protocol string `json:"protocol"` + DialAddress string `json:"dialAddress"` + Dest string `json:"dest"` + DestVersion string `json:"destVersion"` + Dialed bool `json:"dialed"` +} diff --git a/xgress/decoder.go b/xgress/decoder.go new file mode 100644 index 00000000..225df1e5 --- /dev/null +++ b/xgress/decoder.go @@ -0,0 +1,121 @@ +/* + Copyright NetFoundry Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package xgress + +import ( + "fmt" + "github.com/michaelquigley/pfxlog" + "github.com/openziti/channel/v4" +) + +type Decoder struct{} + +const DECODER = "data" + +func (d Decoder) Decode(msg *channel.Message) ([]byte, bool) { + switch msg.ContentType { + case int32(ContentTypePayloadType): + if payload, err := UnmarshallPayload(msg); err == nil { + return DecodePayload(payload) + } else { + pfxlog.Logger().WithError(err).Error("unexpected error unmarshalling payload msg") + } + + case int32(ContentTypeAcknowledgementType): + if ack, err := UnmarshallAcknowledgement(msg); err == nil { + meta := channel.NewTraceMessageDecode(DECODER, "Acknowledgement") + meta["circuitId"] = ack.CircuitId + meta["sequence"] = fmt.Sprintf("len(%d)", len(ack.Sequence)) + switch ack.GetOriginator() { + case Initiator: + meta["originator"] = "i" + case Terminator: + meta["originator"] = "e" + } + + data, err := meta.MarshalTraceMessageDecode() + if err != nil { + return nil, true + } + + return data, true + + } else { + pfxlog.Logger().WithError(err).Error("unexpected error unmarshalling ack msg") + } + case int32(ContentTypeControlType): + if control, err := UnmarshallControl(msg); err == nil { + meta := channel.NewTraceMessageDecode(DECODER, "Control") + meta["circuitId"] = control.CircuitId + meta["type"] = control.Type.String() + if control.Type == ControlTypeTraceRoute || control.Type == ControlTypeTraceRouteResponse { + if ts, found := msg.GetUint64Header(ControlTimestamp); found { + meta["ts"] = ts + } + if hop, found := msg.GetUint32Header(ControlHopCount); found { + meta["hopCount"] = hop + } + if hopType, found := msg.GetStringHeader(ControlHopType); found { + meta["hopType"] = hopType + } + if hopId, found := msg.GetStringHeader(ControlHopId); found { + meta["hopId"] = hopId + } + if userVal, found := msg.GetUint32Header(ControlUserVal); found { + meta["uv"] = userVal + } + if hopErr, found := msg.GetUint32Header(ControlError); found { + meta["err"] = hopErr + } + } + data, err := meta.MarshalTraceMessageDecode() + if err != nil { + return nil, true + } + + return data, true + + } else { + pfxlog.Logger().WithError(err).Error("unexpected error unmarshalling control msg") + } + } + + return nil, false +} + +func DecodePayload(payload *Payload) ([]byte, bool) { + meta := channel.NewTraceMessageDecode(DECODER, "Payload") + meta["circuitId"] = payload.CircuitId + meta["sequence"] = payload.Sequence + switch payload.GetOriginator() { + case Initiator: + meta["originator"] = "i" + case Terminator: + meta["originator"] = "e" + } + if payload.Flags != 0 { + meta["flags"] = payload.Flags + } + meta["length"] = len(payload.Data) + + data, err := meta.MarshalTraceMessageDecode() + if err != nil { + return nil, true + } + + return data, true +} diff --git a/xgress/heartbeat_transformer.go b/xgress/heartbeat_transformer.go new file mode 100644 index 00000000..17abe2aa --- /dev/null +++ b/xgress/heartbeat_transformer.go @@ -0,0 +1,38 @@ +/* + Copyright NetFoundry Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package xgress + +import ( + "encoding/binary" + "github.com/openziti/channel/v4" + "time" +) + +type PayloadTransformer struct { +} + +func (self PayloadTransformer) Rx(*channel.Message, channel.Channel) {} + +func (self PayloadTransformer) Tx(m *channel.Message, ch channel.Channel) { + if m.ContentType == channel.ContentTypeRaw && len(m.Body) > 1 { + if m.Body[0]&HeartbeatFlagMask != 0 && len(m.Body) > 12 { + now := time.Now().UnixNano() + m.PutUint64Header(channel.HeartbeatHeader, uint64(now)) + binary.BigEndian.PutUint64(m.Body[len(m.Body)-8:], uint64(now)) + } + } +} diff --git a/xgress/link_receive_buffer.go b/xgress/link_receive_buffer.go new file mode 100644 index 00000000..5a010ae6 --- /dev/null +++ b/xgress/link_receive_buffer.go @@ -0,0 +1,147 @@ +/* + Copyright NetFoundry Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package xgress + +import ( + "fmt" + "github.com/emirpasic/gods/trees/btree" + "github.com/emirpasic/gods/utils" + "github.com/michaelquigley/pfxlog" + "sync/atomic" + "time" +) + +type LinkReceiveBuffer struct { + tree *btree.Tree + sequence int32 + maxSequence int32 + size uint32 + lastBufferSizeSent uint32 +} + +func NewLinkReceiveBuffer() *LinkReceiveBuffer { + return &LinkReceiveBuffer{ + tree: btree.NewWith(10240, utils.Int32Comparator), + sequence: -1, + } +} + +func (buffer *LinkReceiveBuffer) Size() uint32 { + return atomic.LoadUint32(&buffer.size) +} + +func (buffer *LinkReceiveBuffer) ReceiveUnordered(x *Xgress, payload *Payload, maxSize uint32) bool { + if payload.GetSequence() <= buffer.sequence { + x.dataPlane.GetMetrics().MarkDuplicatePayload() + return true + } + + if atomic.LoadUint32(&buffer.size) > maxSize && payload.Sequence > buffer.maxSequence { + x.dataPlane.GetMetrics().MarkPayloadDropped() + return false + } + + treeSize := buffer.tree.Size() + buffer.tree.Put(payload.GetSequence(), payload) + if buffer.tree.Size() > treeSize { + payloadSize := len(payload.Data) + size := atomic.AddUint32(&buffer.size, uint32(payloadSize)) + pfxlog.Logger().Tracef("Payload %v of size %v added to transmit buffer. New size: %v", payload.Sequence, payloadSize, size) + if payload.Sequence > buffer.maxSequence { + buffer.maxSequence = payload.Sequence + } + } else { + x.dataPlane.GetMetrics().MarkDuplicatePayload() + } + return true +} + +func (buffer *LinkReceiveBuffer) PeekHead() *Payload { + if val := buffer.tree.LeftValue(); val != nil { + payload := val.(*Payload) + if payload.Sequence == buffer.sequence+1 { + return payload + } + } + return nil +} + +func (buffer *LinkReceiveBuffer) Remove(payload *Payload) { + buffer.tree.Remove(payload.Sequence) + buffer.sequence = payload.Sequence +} + +func (buffer *LinkReceiveBuffer) getLastBufferSizeSent() uint32 { + return atomic.LoadUint32(&buffer.lastBufferSizeSent) +} + +func (buffer *LinkReceiveBuffer) Inspect(x *Xgress) *RecvBufferDetail { + timeout := time.After(100 * time.Millisecond) + inspectEvent := &receiveBufferInspectEvent{ + buffer: buffer, + notifyComplete: make(chan *RecvBufferDetail, 1), + } + + if x.dataPlane.GetPayloadIngester().inspect(inspectEvent, timeout) { + select { + case result := <-inspectEvent.notifyComplete: + return result + case <-timeout: + } + } + + return buffer.inspectIncomplete() +} + +func (buffer *LinkReceiveBuffer) inspectComplete() *RecvBufferDetail { + nextPayload := "none" + if head := buffer.tree.LeftValue(); head != nil { + payload := head.(*Payload) + nextPayload = fmt.Sprintf("%v", payload.Sequence) + } + + return &RecvBufferDetail{ + Size: buffer.Size(), + PayloadCount: uint32(buffer.tree.Size()), + LastSizeSent: buffer.getLastBufferSizeSent(), + Sequence: buffer.sequence, + MaxSequence: buffer.maxSequence, + NextPayload: nextPayload, + AcquiredSafely: true, + } +} + +func (buffer *LinkReceiveBuffer) inspectIncomplete() *RecvBufferDetail { + return &RecvBufferDetail{ + Size: buffer.Size(), + LastSizeSent: buffer.getLastBufferSizeSent(), + Sequence: buffer.sequence, + MaxSequence: buffer.maxSequence, + NextPayload: "unsafe to check", + AcquiredSafely: false, + } +} + +type receiveBufferInspectEvent struct { + buffer *LinkReceiveBuffer + notifyComplete chan *RecvBufferDetail +} + +func (self *receiveBufferInspectEvent) handle() { + result := self.buffer.inspectComplete() + self.notifyComplete <- result +} diff --git a/xgress/link_send_buffer.go b/xgress/link_send_buffer.go new file mode 100644 index 00000000..ed145d81 --- /dev/null +++ b/xgress/link_send_buffer.go @@ -0,0 +1,428 @@ +/* + Copyright NetFoundry Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package xgress + +import ( + "github.com/michaelquigley/pfxlog" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "math" + "slices" + "sync/atomic" + "time" +) + +// Note: if altering this struct, be sure to account for 64 bit alignment on 32 bit arm arch +// https://pkg.go.dev/sync/atomic#pkg-note-BUG +// https://github.com/golang/go/issues/36606 +type LinkSendBuffer struct { + retxScale float64 + x *Xgress + buffer map[int32]*txPayload + newlyBuffered chan *txPayload + newlyReceivedAcks chan *Acknowledgement + windowsSize uint32 + linkSendBufferSize uint32 + linkRecvBufferSize uint32 + accumulator uint32 + successfulAcks uint32 + duplicateAcks uint32 + retransmits uint32 + closeNotify chan struct{} + closed atomic.Bool + blockedByLocalWindow bool + blockedByRemoteWindow bool + retxThreshold uint32 + lastRtt uint16 + lastRetransmitTime int64 + closeWhenEmpty atomic.Bool + inspectRequests chan *sendBufferInspectEvent + blockedSince time.Time +} + +type txPayload struct { + age int64 + payload *Payload + retxQueued int32 + x *Xgress + next *txPayload + prev *txPayload +} + +func (self *txPayload) markSent() { + atomic.StoreInt64(&self.age, time.Now().UnixMilli()) +} + +func (self *txPayload) getAge() int64 { + return atomic.LoadInt64(&self.age) +} + +func (self *txPayload) markQueued() { + atomic.AddInt32(&self.retxQueued, 1) +} + +// markAcked marks the payload and acked and returns true if the payload is queued for retransmission +func (self *txPayload) markAcked() bool { + return atomic.AddInt32(&self.retxQueued, 2) > 2 +} + +func (self *txPayload) dequeued() { + atomic.AddInt32(&self.retxQueued, -1) +} + +func (self *txPayload) isAcked() bool { + return atomic.LoadInt32(&self.retxQueued) > 1 +} + +func (self *txPayload) isRetransmittable() bool { + return atomic.LoadInt32(&self.retxQueued) == 0 +} + +func NewLinkSendBuffer(x *Xgress) *LinkSendBuffer { + logrus.Debugf("txPortalStartSize = %d, txPortalMinSize = %d", + x.Options.TxPortalStartSize, + x.Options.TxPortalMinSize) + + // newlyBuffered should be size 0, otherwise payloads can be sent and acks received before the payload is + // processed by the LinkSendBuffer + buffer := &LinkSendBuffer{ + x: x, + buffer: make(map[int32]*txPayload), + newlyBuffered: make(chan *txPayload), + newlyReceivedAcks: make(chan *Acknowledgement, 2), + closeNotify: make(chan struct{}), + windowsSize: x.Options.TxPortalStartSize, + retxThreshold: x.Options.RetxStartMs, + retxScale: x.Options.RetxScale, + inspectRequests: make(chan *sendBufferInspectEvent, 1), + } + + go buffer.run() + return buffer +} + +func (buffer *LinkSendBuffer) CloseWhenEmpty() bool { + return buffer.closeWhenEmpty.CompareAndSwap(false, true) +} + +func (buffer *LinkSendBuffer) BufferPayload(payload *Payload) (func(), error) { + txPayload := &txPayload{payload: payload, age: math.MaxInt64, x: buffer.x} + select { + case buffer.newlyBuffered <- txPayload: + pfxlog.ContextLogger(buffer.x.Label()).Debugf("buffered [%d]", payload.GetSequence()) + return txPayload.markSent, nil + case <-buffer.closeNotify: + return nil, errors.Errorf("payload buffer closed") + } +} + +func (buffer *LinkSendBuffer) ReceiveAcknowledgement(ack *Acknowledgement) { + log := pfxlog.ContextLogger(buffer.x.Label()).WithFields(ack.GetLoggerFields()) + log.Debug("ack received") + select { + case buffer.newlyReceivedAcks <- ack: + log.Debug("ack processed") + case <-buffer.closeNotify: + // if end of circuit was received, we've cleanly shutdown and can ignore any trailing acks + if buffer.x.IsEndOfCircuitReceived() { + log.Debug("payload buffer closed") + } else { + log.Error("payload buffer closed") + } + } +} + +func (buffer *LinkSendBuffer) metrics() Metrics { + return buffer.x.dataPlane.GetMetrics() +} + +func (buffer *LinkSendBuffer) Close() { + pfxlog.ContextLogger(buffer.x.Label()).Debugf("[%p] closing", buffer) + if buffer.closed.CompareAndSwap(false, true) { + close(buffer.closeNotify) + } +} + +func (buffer *LinkSendBuffer) isBlocked() bool { + wasBlocked := buffer.blockedByLocalWindow || buffer.blockedByRemoteWindow + blocked := false + + if buffer.x.Options.TxPortalMaxSize < buffer.linkRecvBufferSize { + blocked = true + if !buffer.blockedByRemoteWindow { + buffer.blockedByRemoteWindow = true + buffer.metrics().BufferBlockedByRemoteWindow() + } + } else if buffer.blockedByRemoteWindow { + buffer.blockedByRemoteWindow = false + buffer.metrics().BufferUnblockedByRemoteWindow() + } + + if buffer.windowsSize < buffer.linkSendBufferSize { + blocked = true + if !buffer.blockedByLocalWindow { + buffer.blockedByLocalWindow = true + buffer.metrics().BufferBlockedByLocalWindow() + } + } else if buffer.blockedByLocalWindow { + buffer.blockedByLocalWindow = false + buffer.metrics().BufferUnblockedByLocalWindow() + } + + if blocked { + if !wasBlocked { + buffer.blockedSince = time.Now() + } + pfxlog.ContextLogger(buffer.x.Label()).Debugf("blocked=%v win_size=%v tx_buffer_size=%v rx_buffer_size=%v", blocked, buffer.windowsSize, buffer.linkSendBufferSize, buffer.linkRecvBufferSize) + } else if wasBlocked { + buffer.metrics().BufferUnblocked(time.Since(buffer.blockedSince)) + } + + return blocked +} + +func (buffer *LinkSendBuffer) run() { + log := pfxlog.ContextLogger(buffer.x.Label()) + defer log.Debugf("[%p] exited", buffer) + log.Debugf("[%p] started", buffer) + + var buffered chan *txPayload + + retransmitTicker := time.NewTicker(100 * time.Millisecond) + defer retransmitTicker.Stop() + + for { + // bias acks, process all pending, since that should not block + select { + case ack := <-buffer.newlyReceivedAcks: + buffer.receiveAcknowledgement(ack) + case <-buffer.closeNotify: + buffer.close() + return + default: + } + + // don't block when we're closing, since the only thing that should still be coming in is end-of-circuit + // if we're blocked, but empty, let one payload in to reduce the chances of a stall + if buffer.isBlocked() && !buffer.closeWhenEmpty.Load() && buffer.linkSendBufferSize != 0 { + buffered = nil + } else { + buffered = buffer.newlyBuffered + + select { + case txPayload := <-buffered: + buffer.buffer[txPayload.payload.GetSequence()] = txPayload + payloadSize := len(txPayload.payload.Data) + buffer.linkSendBufferSize += uint32(payloadSize) + buffer.metrics().SendPayloadBuffered(int64(payloadSize)) + log.Tracef("buffering payload %v with size %v. payload buffer size: %v", + txPayload.payload.Sequence, len(txPayload.payload.Data), buffer.linkSendBufferSize) + case <-buffer.closeNotify: + buffer.close() + return + default: + } + } + + select { + case inspectEvent := <-buffer.inspectRequests: + inspectEvent.handle(buffer) + + case ack := <-buffer.newlyReceivedAcks: + buffer.receiveAcknowledgement(ack) + buffer.retransmit() + if buffer.closeWhenEmpty.Load() && len(buffer.buffer) == 0 && !buffer.x.Closed() && buffer.x.IsEndOfCircuitSent() { + go buffer.x.Close() + } + + case txPayload := <-buffered: + buffer.buffer[txPayload.payload.GetSequence()] = txPayload + payloadSize := len(txPayload.payload.Data) + buffer.linkSendBufferSize += uint32(payloadSize) + buffer.metrics().SendPayloadBuffered(int64(payloadSize)) + log.Tracef("buffering payload %v with size %v. payload buffer size: %v", + txPayload.payload.Sequence, len(txPayload.payload.Data), buffer.linkSendBufferSize) + + case <-retransmitTicker.C: + buffer.retransmit() + + case <-buffer.closeNotify: + buffer.close() + return + } + } +} + +func (buffer *LinkSendBuffer) close() { + if buffer.blockedByLocalWindow { + buffer.metrics().BufferUnblockedByLocalWindow() + } + if buffer.blockedByRemoteWindow { + buffer.metrics().BufferUnblockedByRemoteWindow() + } +} + +func (buffer *LinkSendBuffer) receiveAcknowledgement(ack *Acknowledgement) { + log := pfxlog.ContextLogger(buffer.x.Label()).WithFields(ack.GetLoggerFields()) + + for _, sequence := range ack.Sequence { + if txPayload, found := buffer.buffer[sequence]; found { + if txPayload.markAcked() { // if it's been queued for retransmission, remove it from the queue + buffer.x.dataPlane.GetRetransmitter().queue(txPayload) + } + + payloadSize := uint32(len(txPayload.payload.Data)) + buffer.accumulator += payloadSize + buffer.successfulAcks++ + delete(buffer.buffer, sequence) + buffer.metrics().SendPayloadDelivered(int64(payloadSize)) + buffer.linkSendBufferSize -= payloadSize + log.Debugf("removing payload %v with size %v. payload buffer size: %v", + txPayload.payload.Sequence, len(txPayload.payload.Data), buffer.linkSendBufferSize) + + if buffer.successfulAcks >= buffer.x.Options.TxPortalIncreaseThresh { + buffer.successfulAcks = 0 + delta := uint32(float64(buffer.accumulator) * buffer.x.Options.TxPortalIncreaseScale) + buffer.windowsSize += delta + if buffer.windowsSize > buffer.x.Options.TxPortalMaxSize { + buffer.windowsSize = buffer.x.Options.TxPortalMaxSize + } + buffer.retxScale -= 0.01 + if buffer.retxScale < buffer.x.Options.RetxScale { + buffer.retxScale = buffer.x.Options.RetxScale + } + } + } else { // duplicate ack + buffer.metrics().MarkDuplicateAck() + buffer.duplicateAcks++ + if buffer.duplicateAcks >= buffer.x.Options.TxPortalDupAckThresh { + buffer.duplicateAcks = 0 + buffer.retxScale += 0.2 + } + } + } + + buffer.linkRecvBufferSize = ack.RecvBufferSize + if ack.RTT > 0 { + rtt := uint16(time.Now().UnixMilli()) - ack.RTT + if buffer.lastRtt > 0 { + rtt = (rtt + buffer.lastRtt) >> 1 + } + buffer.lastRtt = rtt + buffer.retxThreshold = uint32(float64(rtt)*buffer.retxScale) + buffer.x.Options.RetxAddMs + } +} + +func (buffer *LinkSendBuffer) retransmit() { + now := time.Now().UnixMilli() + if len(buffer.buffer) > 0 && (now-buffer.lastRetransmitTime) > 64 { + log := pfxlog.ContextLogger(buffer.x.Label()) + + retransmitted := 0 + var rtxList []*txPayload + for _, v := range buffer.buffer { + age := v.getAge() + if age != math.MaxInt64 && v.isRetransmittable() && uint32(now-age) >= buffer.retxThreshold { + rtxList = append(rtxList, v) + } + } + + slices.SortFunc(rtxList, func(a, b *txPayload) int { + return int(a.payload.Sequence - b.payload.Sequence) + }) + + for _, v := range rtxList { + v.markQueued() + buffer.x.dataPlane.GetRetransmitter().queue(v) + retransmitted++ + buffer.retransmits++ + if buffer.retransmits >= buffer.x.Options.TxPortalRetxThresh { + buffer.accumulator = 0 + buffer.retransmits = 0 + buffer.scale(buffer.x.Options.TxPortalRetxScale) + } + } + + if retransmitted > 0 { + log.Debugf("retransmitted [%d] payloads, [%d] buffered, linkSendBufferSize: %d", retransmitted, len(buffer.buffer), buffer.linkSendBufferSize) + } + buffer.lastRetransmitTime = now + } +} + +func (buffer *LinkSendBuffer) scale(factor float64) { + buffer.windowsSize = uint32(float64(buffer.windowsSize) * factor) + if factor > 1 { + if buffer.windowsSize > buffer.x.Options.TxPortalMaxSize { + buffer.windowsSize = buffer.x.Options.TxPortalMaxSize + } + } else if buffer.windowsSize < buffer.x.Options.TxPortalMinSize { + buffer.windowsSize = buffer.x.Options.TxPortalMinSize + } +} + +func (buffer *LinkSendBuffer) inspect() *SendBufferDetail { + timeSinceLastRetransmit := time.Duration(time.Now().UnixMilli()-buffer.lastRetransmitTime) * time.Millisecond + result := &SendBufferDetail{ + WindowSize: buffer.windowsSize, + LinkSendBufferSize: buffer.linkSendBufferSize, + LinkRecvBufferSize: buffer.linkRecvBufferSize, + Accumulator: buffer.accumulator, + SuccessfulAcks: buffer.successfulAcks, + DuplicateAcks: buffer.duplicateAcks, + Retransmits: buffer.retransmits, + Closed: buffer.closed.Load(), + BlockedByLocalWindow: buffer.blockedByLocalWindow, + BlockedByRemoteWindow: buffer.blockedByRemoteWindow, + RetxScale: buffer.retxScale, + RetxThreshold: buffer.retxThreshold, + TimeSinceLastRetx: timeSinceLastRetransmit.String(), + CloseWhenEmpty: buffer.closeWhenEmpty.Load(), + } + return result +} + +func (buffer *LinkSendBuffer) Inspect() *SendBufferDetail { + timeout := time.After(100 * time.Millisecond) + inspectEvent := &sendBufferInspectEvent{ + notifyComplete: make(chan *SendBufferDetail, 1), + } + + select { + case buffer.inspectRequests <- inspectEvent: + select { + case result := <-inspectEvent.notifyComplete: + result.AcquiredSafely = true + return result + case <-timeout: + } + case <-timeout: + } + + result := buffer.inspect() + result.AcquiredSafely = false + return result +} + +type sendBufferInspectEvent struct { + notifyComplete chan *SendBufferDetail +} + +func (self *sendBufferInspectEvent) handle(buffer *LinkSendBuffer) { + result := buffer.inspect() + self.notifyComplete <- result +} diff --git a/xgress/messages.go b/xgress/messages.go new file mode 100644 index 00000000..01b1cded --- /dev/null +++ b/xgress/messages.go @@ -0,0 +1,457 @@ +/* + Copyright NetFoundry Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package xgress + +import ( + "encoding/binary" + "fmt" + "github.com/openziti/channel/v4" + "github.com/openziti/foundation/v2/info" + "github.com/openziti/foundation/v2/uuidz" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "math" +) + +const ( + MinHeaderKey = 2000 + MaxHeaderKey = MinHeaderKey + int32(math.MaxUint8) + + HeaderKeyCircuitId = 2256 + HeaderKeySequence = 2257 + HeaderKeyFlags = 2258 + HeaderKeyRecvBufferSize = 2259 + HeaderKeyRTT = 2260 + HeaderPayloadRaw = 2261 + + ContentTypePayloadType = 1100 + ContentTypeAcknowledgementType = 1101 + ContentTypeControlType = 1102 +) + +var ContentTypeValue = map[string]int32{ + "PayloadType": ContentTypePayloadType, + "AcknowledgementType": ContentTypeAcknowledgementType, + "ControlType": ContentTypeControlType, +} + +type Originator int32 + +const ( + Initiator Originator = 0 + Terminator Originator = 1 +) + +func (o Originator) String() string { + if o == Initiator { + return "Initiator" + } + return "Terminator" +} + +type Flag uint32 + +const ( + PayloadFlagCircuitEnd Flag = 1 + PayloadFlagOriginator Flag = 2 + PayloadFlagCircuitStart Flag = 4 + PayloadFlagChunk Flag = 8 +) + +func NewAcknowledgement(circuitId string, originator Originator) *Acknowledgement { + return &Acknowledgement{ + CircuitId: circuitId, + Flags: SetOriginatorFlag(0, originator), + } +} + +type Acknowledgement struct { + CircuitId string + Flags uint32 + RecvBufferSize uint32 + RTT uint16 + Sequence []int32 +} + +func (ack *Acknowledgement) GetCircuitId() string { + return ack.CircuitId +} + +func (ack *Acknowledgement) GetFlags() uint32 { + return ack.Flags +} + +func (ack *Acknowledgement) GetOriginator() Originator { + if isFlagSet(ack.Flags, PayloadFlagOriginator) { + return Terminator + } + return Initiator +} + +func (ack *Acknowledgement) GetSequence() []int32 { + return ack.Sequence +} + +func (ack *Acknowledgement) marshallSequence() []byte { + if len(ack.Sequence) == 0 { + return nil + } + buf := make([]byte, len(ack.Sequence)*4) + nextWriteBuf := buf + for _, seq := range ack.Sequence { + binary.BigEndian.PutUint32(nextWriteBuf, uint32(seq)) + nextWriteBuf = nextWriteBuf[4:] + } + return buf +} + +func (ack *Acknowledgement) unmarshallSequence(data []byte) error { + if len(data) == 0 { + return nil + } + + if len(data)%4 != 0 { + return fmt.Errorf("received sequence with wrong number of bytes: %v", len(data)) + } + ack.Sequence = make([]int32, len(data)/4) + + nextReadBuf := data + for i := range ack.Sequence { + ack.Sequence[i] = int32(binary.BigEndian.Uint32(nextReadBuf)) + nextReadBuf = nextReadBuf[4:] + } + return nil +} + +func (ack *Acknowledgement) Marshall() *channel.Message { + msg := channel.NewMessage(ContentTypeAcknowledgementType, ack.marshallSequence()) + msg.PutUint16Header(HeaderKeyRTT, ack.RTT) + msg.Headers[HeaderKeyCircuitId] = []byte(ack.CircuitId) + if ack.Flags != 0 { + msg.PutUint32Header(HeaderKeyFlags, ack.Flags) + } + msg.PutUint32Header(HeaderKeyRecvBufferSize, ack.RecvBufferSize) + return msg +} + +func UnmarshallAcknowledgement(msg *channel.Message) (*Acknowledgement, error) { + ack := &Acknowledgement{} + + circuitId, ok := msg.Headers[HeaderKeyCircuitId] + if !ok { + return nil, fmt.Errorf("no circuitId found in xgress payload message") + } + + // If no flags are present, it just means no flags have been set + flags, _ := msg.GetUint32Header(HeaderKeyFlags) + + ack.CircuitId = string(circuitId) + ack.Flags = flags + if ack.RecvBufferSize, ok = msg.GetUint32Header(HeaderKeyRecvBufferSize); !ok { + ack.RecvBufferSize = math.MaxUint32 + } + + ack.RTT, _ = msg.GetUint16Header(HeaderKeyRTT) + + if err := ack.unmarshallSequence(msg.Body); err != nil { + return nil, err + } + + return ack, nil +} + +func (ack *Acknowledgement) GetLoggerFields() logrus.Fields { + return logrus.Fields{ + "circuitId": ack.CircuitId, + "linkRecvBufferSize": ack.RecvBufferSize, + "seq": fmt.Sprintf("%+v", ack.Sequence), + "RTT": ack.RTT, + } +} + +type PayloadType byte + +const ( + PayloadTypeXg PayloadType = 1 + PayloadTypeRtx PayloadType = 2 + PayloadTypeFwd PayloadType = 3 +) + +type Payload struct { + CircuitId string + Flags uint32 + RTT uint16 + Sequence int32 + Headers map[uint8][]byte + Data []byte + raw []byte +} + +func (payload *Payload) GetSequence() int32 { + return payload.Sequence +} + +func (payload *Payload) Marshall() *channel.Message { + if payload.raw != nil { + if payload.raw[0]&RttFlagMask != 0 { + rtt := uint16(info.NowInMilliseconds()) + b0 := byte(rtt) + b1 := byte(rtt >> 8) + payload.raw[2] = b0 + payload.raw[3] = b1 + } + return channel.NewMessage(channel.ContentTypeRaw, payload.raw) + } + + msg := channel.NewMessage(ContentTypePayloadType, payload.Data) + addPayloadHeadersToMsg(msg, payload.Headers) + msg.Headers[HeaderKeyCircuitId] = []byte(payload.CircuitId) + if payload.Flags != 0 { + msg.PutUint32Header(HeaderKeyFlags, payload.Flags) + } + + msg.PutUint64Header(HeaderKeySequence, uint64(payload.Sequence)) + msg.PutUint16Header(HeaderKeyRTT, uint16(info.NowInMilliseconds())) + + return msg +} + +func addPayloadHeadersToMsg(msg *channel.Message, headers map[uint8][]byte) { + for key, value := range headers { + msgHeaderKey := MinHeaderKey + int32(key) + msg.Headers[msgHeaderKey] = value + } +} + +func UnmarshallPayload(msg *channel.Message) (*Payload, error) { + var headers map[uint8][]byte + for key, val := range msg.Headers { + if key >= MinHeaderKey && key <= MaxHeaderKey { + if headers == nil { + headers = make(map[uint8][]byte) + } + xgressHeaderKey := uint8(key - MinHeaderKey) + headers[xgressHeaderKey] = val + } + } + + payload := &Payload{ + Headers: headers, + Data: msg.Body, + } + + circuitId, ok := msg.Headers[HeaderKeyCircuitId] + if !ok { + return nil, fmt.Errorf("no circuitId found in xgress payload message") + } + + // If no flags are present, it just means no flags have been set + flags, _ := msg.GetUint32Header(HeaderKeyFlags) + + payload.CircuitId = string(circuitId) + payload.Flags = flags + + payload.RTT, _ = msg.GetUint16Header(HeaderKeyRTT) + + sequence, ok := msg.GetUint64Header(HeaderKeySequence) + if !ok { + return nil, fmt.Errorf("no sequence found in xgress payload message") + } + payload.Sequence = int32(sequence) + + if raw, ok := msg.Headers[HeaderPayloadRaw]; ok { + payload.raw = raw + } + + return payload, nil +} + +func isFlagSet(flags uint32, flag Flag) bool { + return Flag(flags)&flag == flag +} + +func setPayloadFlag(flags uint32, flag Flag) uint32 { + return uint32(Flag(flags) | flag) +} + +func (payload *Payload) GetCircuitId() string { + return payload.CircuitId +} + +func (payload *Payload) GetFlags() uint32 { + return payload.Flags +} + +func (payload *Payload) IsCircuitEndFlagSet() bool { + return isFlagSet(payload.Flags, PayloadFlagCircuitEnd) +} + +func (payload *Payload) IsCircuitStartFlagSet() bool { + return isFlagSet(payload.Flags, PayloadFlagCircuitStart) +} + +func (payload *Payload) GetOriginator() Originator { + if isFlagSet(payload.Flags, PayloadFlagOriginator) { + return Terminator + } + return Initiator +} + +func SetOriginatorFlag(flags uint32, originator Originator) uint32 { + if originator == Initiator { + return ^uint32(PayloadFlagOriginator) & flags + } + return uint32(PayloadFlagOriginator) | flags +} + +func (payload *Payload) GetLoggerFields() logrus.Fields { + result := logrus.Fields{ + "circuitId": payload.CircuitId, + "seq": payload.Sequence, + "origin": payload.GetOriginator(), + } + + if uuidVal, found := payload.Headers[HeaderKeyUUID]; found { + result["uuid"] = uuidz.ToString(uuidVal) + } + + return result +} + +type ControlType byte + +func (self ControlType) String() string { + switch self { + case ControlTypeTraceRoute: + return "traceroute" + case ControlTypeTraceRouteResponse: + return "traceroute_response" + default: + return fmt.Sprintf("unhandled: %v", byte(self)) + } +} + +const ( + ControlTypeTraceRoute ControlType = 1 + ControlTypeTraceRouteResponse ControlType = 2 +) + +const ( + ControlHopCount = 20 + ControlHopType = 21 + ControlHopId = 22 + ControlTimestamp = 23 + ControlUserVal = 24 + ControlError = 25 +) + +type Control struct { + Type ControlType + CircuitId string + Headers channel.Headers +} + +func (self *Control) Marshall() *channel.Message { + msg := channel.NewMessage(ContentTypeControlType, append([]byte{byte(self.Type)}, self.CircuitId...)) + msg.Headers = self.Headers + return msg +} + +func UnmarshallControl(msg *channel.Message) (*Control, error) { + if len(msg.Body) < 2 { + return nil, errors.New("control message body too short") + } + return &Control{ + Type: ControlType(msg.Body[0]), + CircuitId: string(msg.Body[1:]), + Headers: msg.Headers, + }, nil +} + +func (self *Control) IsTypeTraceRoute() bool { + return self.Type == ControlTypeTraceRoute +} + +func (self *Control) IsTypeTraceRouteResponse() bool { + return self.Type == ControlTypeTraceRouteResponse +} + +func (self *Control) DecrementAndGetHop() uint32 { + hop, _ := self.Headers.GetUint32Header(ControlHopCount) + if hop == 0 { + return 0 + } + hop-- + self.Headers.PutUint32Header(ControlHopCount, hop) + return hop +} + +func (self *Control) CreateTraceResponse(hopType, hopId string) *Control { + resp := &Control{ + Type: ControlTypeTraceRouteResponse, + CircuitId: self.CircuitId, + Headers: self.Headers, + } + resp.Headers.PutStringHeader(ControlHopType, hopType) + resp.Headers.PutStringHeader(ControlHopId, hopId) + return resp +} + +func (self *Control) GetLoggerFields() logrus.Fields { + result := logrus.Fields{ + "circuitId": self.CircuitId, + "type": self.Type, + } + + if uuidVal, found := self.Headers[HeaderKeyUUID]; found { + result["uuid"] = uuidz.ToString(uuidVal) + } + + return result +} + +func RespondToTraceRequest(headers channel.Headers, hopType, hopId string, response ControlReceiver) { + resp := &Control{Headers: headers} + resp.DecrementAndGetHop() + resp.Headers.PutStringHeader(ControlHopType, hopType) + resp.Headers.PutStringHeader(ControlHopId, hopId) + response.HandleControlReceive(ControlTypeTraceRouteResponse, headers) +} + +type InvalidTerminatorError struct { + InnerError error +} + +func (e InvalidTerminatorError) Error() string { + return e.InnerError.Error() +} + +func (e InvalidTerminatorError) Unwrap() error { + return e.InnerError +} + +type MisconfiguredTerminatorError struct { + InnerError error +} + +func (e MisconfiguredTerminatorError) Error() string { + return e.InnerError.Error() +} + +func (e MisconfiguredTerminatorError) Unwrap() error { + return e.InnerError +} diff --git a/xgress/messages_test.go b/xgress/messages_test.go new file mode 100644 index 00000000..97114ffc --- /dev/null +++ b/xgress/messages_test.go @@ -0,0 +1,121 @@ +/* + Copyright NetFoundry Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package xgress + +import ( + "github.com/stretchr/testify/assert" + "reflect" + "sync/atomic" + "testing" +) + +// A simple test to check for failure of alignment on atomic operations for 64 bit variables in a struct +func Test64BitAlignment(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Errorf("One of the variables that was tested is not properly 64-bit aligned.") + } + }() + + lsb := Xgress{} + tPayload := txPayload{} + reTx := Retransmitter{} + + atomic.LoadInt64(&lsb.timeOfLastRxFromLink) + atomic.LoadInt64(&tPayload.age) + atomic.LoadInt64(&reTx.retransmitsQueueSize) +} + +func TestSetOriginatorFlag(t *testing.T) { + type args struct { + flags uint32 + originator Originator + } + tests := []struct { + name string + args args + want uint32 + }{ + + {name: "set empty to ingress", + args: args{ + flags: 0, + originator: Initiator, + }, + want: 0, + }, + {name: "set end of circuit to ingress", + args: args{ + flags: uint32(PayloadFlagCircuitEnd), + originator: Initiator, + }, + want: uint32(PayloadFlagCircuitEnd), + }, + {name: "set empty to egress", + args: args{ + flags: 0, + originator: Terminator, + }, + want: uint32(PayloadFlagOriginator), + }, + {name: "set end of circuit to egress", + args: args{ + flags: uint32(PayloadFlagCircuitEnd), + originator: Terminator, + }, + want: uint32(PayloadFlagCircuitEnd) | uint32(PayloadFlagOriginator), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := SetOriginatorFlag(tt.args.flags, tt.args.originator); got != tt.want { + t.Errorf("SetOriginatorFlag() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAcknowledgement_marshallSequence(t *testing.T) { + tests := []struct { + name string + sequence []int32 + }{ + + {name: "nil", sequence: nil}, + {name: "empty", sequence: make([]int32, 0)}, + {name: "one entry", sequence: []int32{1}}, + {name: "many entries", sequence: []int32{1, -1, 100, 200, -3213232, 421123, -58903204, -4324, 432432, 0, 9}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ack := &Acknowledgement{ + Sequence: tt.sequence, + } + got := ack.marshallSequence() + ack2 := &Acknowledgement{} + err := ack2.unmarshallSequence(got) + assert.NoError(t, err) + + if len(ack.Sequence) == 0 { + return + } + if !reflect.DeepEqual(ack, ack2) { + t.Errorf("marshallSequence() = %v, want %v", ack2, ack) + } + }) + } +} diff --git a/xgress/metrics.go b/xgress/metrics.go new file mode 100644 index 00000000..0f29607f --- /dev/null +++ b/xgress/metrics.go @@ -0,0 +1,126 @@ +package xgress + +import ( + "github.com/openziti/metrics" + "sync/atomic" + "time" +) + +type Metrics interface { + MarkAckReceived() + MarkPayloadDropped() + MarkDuplicateAck() + MarkDuplicatePayload() + BufferBlockedByLocalWindow() + BufferUnblockedByLocalWindow() + BufferBlockedByRemoteWindow() + BufferUnblockedByRemoteWindow() + + PayloadWritten(duration time.Duration) + BufferUnblocked(duration time.Duration) + + SendPayloadBuffered(payloadSize int64) + SendPayloadDelivered(payloadSize int64) +} + +type metricsImpl struct { + ackRxMeter metrics.Meter + droppedPayloadsMeter metrics.Meter + + payloadWriteTimer metrics.Timer + duplicateAcksMeter metrics.Meter + duplicatePayloadsMeter metrics.Meter + + buffersBlockedByLocalWindow int64 + buffersBlockedByRemoteWindow int64 + outstandingPayloads int64 + outstandingPayloadBytes int64 + + buffersBlockedByLocalWindowMeter metrics.Meter + buffersBlockedByRemoteWindowMeter metrics.Meter + + bufferBlockedTime metrics.Timer +} + +func (self *metricsImpl) SendPayloadBuffered(payloadSize int64) { + atomic.AddInt64(&self.outstandingPayloads, 1) + atomic.AddInt64(&self.outstandingPayloadBytes, payloadSize) +} + +func (self *metricsImpl) SendPayloadDelivered(payloadSize int64) { + atomic.AddInt64(&self.outstandingPayloads, -1) + atomic.AddInt64(&self.outstandingPayloadBytes, -payloadSize) +} + +func (self *metricsImpl) MarkAckReceived() { + self.ackRxMeter.Mark(1) +} + +func (self *metricsImpl) MarkPayloadDropped() { + self.droppedPayloadsMeter.Mark(1) +} + +func (self *metricsImpl) MarkDuplicateAck() { + self.duplicateAcksMeter.Mark(1) +} + +func (self *metricsImpl) MarkDuplicatePayload() { + self.duplicatePayloadsMeter.Mark(1) +} + +func (self *metricsImpl) BufferBlockedByLocalWindow() { + atomic.AddInt64(&self.buffersBlockedByLocalWindow, 1) + self.buffersBlockedByLocalWindowMeter.Mark(1) +} + +func (self *metricsImpl) BufferUnblockedByLocalWindow() { + atomic.AddInt64(&self.buffersBlockedByLocalWindow, -1) +} + +func (self *metricsImpl) BufferBlockedByRemoteWindow() { + atomic.AddInt64(&self.buffersBlockedByRemoteWindow, 1) + self.buffersBlockedByRemoteWindowMeter.Mark(1) +} + +func (self *metricsImpl) BufferUnblockedByRemoteWindow() { + atomic.AddInt64(&self.buffersBlockedByRemoteWindow, -1) +} + +func (self *metricsImpl) PayloadWritten(duration time.Duration) { + self.payloadWriteTimer.Update(duration) +} + +func (self *metricsImpl) BufferUnblocked(duration time.Duration) { + self.bufferBlockedTime.Update(duration) +} + +func NewMetrics(registry metrics.Registry) Metrics { + impl := &metricsImpl{ + droppedPayloadsMeter: registry.Meter("xgress.dropped_payloads"), + ackRxMeter: registry.Meter("xgress.rx.acks"), + payloadWriteTimer: registry.Timer("xgress.tx_write_time"), + duplicateAcksMeter: registry.Meter("xgress.ack_duplicates"), + duplicatePayloadsMeter: registry.Meter("xgress.payload_duplicates"), + buffersBlockedByLocalWindowMeter: registry.Meter("xgress.blocked_by_local_window_rate"), + buffersBlockedByRemoteWindowMeter: registry.Meter("xgress.blocked_by_remote_window_rate"), + bufferBlockedTime: registry.Timer("xgress.blocked_time"), + } + + registry.FuncGauge("xgress.blocked_by_local_window", func() int64 { + return atomic.LoadInt64(&impl.buffersBlockedByLocalWindow) + }) + + registry.FuncGauge("xgress.blocked_by_remote_window", func() int64 { + return atomic.LoadInt64(&impl.buffersBlockedByRemoteWindow) + }) + + registry.FuncGauge("xgress.tx_unacked_payloads", func() int64 { + return atomic.LoadInt64(&impl.outstandingPayloads) + }) + + registry.FuncGauge("xgress.tx_unacked_payload_bytes", func() int64 { + return atomic.LoadInt64(&impl.outstandingPayloadBytes) + }) + + return impl +} diff --git a/xgress/minimal_payload_test.go b/xgress/minimal_payload_test.go new file mode 100644 index 00000000..3ee25630 --- /dev/null +++ b/xgress/minimal_payload_test.go @@ -0,0 +1,389 @@ +package xgress + +import ( + "encoding/binary" + "errors" + "fmt" + "github.com/michaelquigley/pfxlog" + "github.com/openziti/channel/v4" + "github.com/openziti/metrics" + cmap "github.com/orcaman/concurrent-map/v2" + metrics2 "github.com/rcrowley/go-metrics" + "github.com/sirupsen/logrus" + "io" + "testing" + "time" +) + +func newTestXgConn(bufferSize int, targetSends uint32, targetReceives uint32) *testXgConn { + return &testXgConn{ + bufferSize: bufferSize, + targetSends: targetSends, + targetReceives: targetReceives, + done: make(chan struct{}), + errs: make(chan error, 1), + } +} + +type testXgConn struct { + sndMsgCounter uint32 + rcvMsgCounter uint32 + bufferSize int + targetSends uint32 + targetReceives uint32 + sendCounter uint32 + recvCounter uint32 + done chan struct{} + errs chan error + bufCounter uint32 +} + +func (self *testXgConn) Close() error { + return nil +} + +func (self *testXgConn) LogContext() string { + return "test" +} + +func (self *testXgConn) ReadPayload() ([]byte, map[uint8][]byte, error) { + self.sndMsgCounter++ + if self.targetSends == 0 { + time.Sleep(time.Minute) + } + var m map[uint8][]byte + buf := make([]byte, self.bufferSize) + sl := buf + for len(sl) > 0 && self.sendCounter < self.targetSends { + binary.BigEndian.PutUint32(sl, self.sendCounter) + self.sendCounter++ + sl = sl[4:] + } + + if len(sl) > 0 { + buf = buf[:len(buf)-len(sl)] + } + + if self.sndMsgCounter%10 == 0 { + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, self.sndMsgCounter) + m = map[uint8][]byte{ + 5: b, + } + if self.sndMsgCounter%20 == 0 { + m[10] = []byte("hello") + } + } + + if self.sendCounter >= self.targetSends { + //fmt.Printf("sending final %d bytes\n", len(buf)) + return buf, nil, io.EOF + } + + //fmt.Printf("sending %d bytes\n", len(buf)) + + return buf, m, nil +} + +func (self *testXgConn) WritePayload(buf []byte, m map[uint8][]byte) (int, error) { + self.rcvMsgCounter++ + sl := buf + for len(sl) > 0 { + next := binary.BigEndian.Uint32(sl) + sl = sl[4:] + if next != self.recvCounter { + select { + case self.errs <- fmt.Errorf("expected counter %d, got %d, buf: %d", self.recvCounter, next, self.bufCounter): + default: + } + } + self.recvCounter++ + if self.recvCounter == self.targetReceives { + close(self.done) + } else if self.recvCounter > self.targetReceives { + select { + case self.errs <- fmt.Errorf("exceeded expected counter %d, got %d, buf: %d", self.targetReceives, self.recvCounter, self.bufCounter): + default: + } + } + } + + if self.rcvMsgCounter%10 == 0 { + b, ok := m[5] + if !ok { + select { + case self.errs <- fmt.Errorf("expected header 5, got %+v headers, rcv count: %d", m, self.rcvMsgCounter): + default: + } + } else if len(b) != 4 { + select { + case self.errs <- fmt.Errorf("expected header 5, len 4, got %+v, rcv count: %d", b, self.rcvMsgCounter): + default: + } + } else { + v := binary.BigEndian.Uint32(b) + if v != self.rcvMsgCounter { + select { + case self.errs <- fmt.Errorf("expected header counter %d, got %d", self.rcvMsgCounter, v): + default: + } + } + } + if self.rcvMsgCounter%20 == 0 { + if string(m[10]) != "hello" { + select { + case self.errs <- fmt.Errorf("missing 10:hello in map, counter %d", self.recvCounter): + default: + } + } + } + } + + //fmt.Printf("received %d bytes\n", len(buf)) + self.bufCounter++ + + return len(buf), nil +} + +func (self *testXgConn) HandleControlMsg(ControlType, channel.Headers, ControlReceiver) error { + panic("implement me") +} + +type testIntermediary struct { + acker AckSender + rtx *Retransmitter + payloadIngester *PayloadIngester + circuitId string + dest *Xgress + msgs channel.MessageStrategy + payloadTransformer PayloadTransformer + counter uint64 + bytesCallback func([]byte) +} + +func (self *testIntermediary) GetRetransmitter() *Retransmitter { + return self.rtx +} + +func (self *testIntermediary) GetPayloadIngester() *PayloadIngester { + return self.payloadIngester +} + +func (self *testIntermediary) GetMetrics() Metrics { + return noopMetrics{} +} + +func (self *testIntermediary) ForwardAcknowledgement(ack *Acknowledgement, address Address) { + self.acker.SendAck(ack, address) +} + +func (self *testIntermediary) ForwardPayload(payload *Payload, x *Xgress) { + m := payload.Marshall() + self.payloadTransformer.Tx(m, nil) + b, err := self.msgs.GetMarshaller()(m) + if err != nil { + panic(err) + } + + if self.bytesCallback != nil { + self.bytesCallback(b) + } + + m, err = self.msgs.GetPacketProducer()(b) + if err != nil { + logrus.WithError(err).Error("error get next msg") + panic(err) + } + + if err = self.validateMessage(m); err != nil { + panic(err) + } + + payload, err = UnmarshallPayload(m) + if err != nil { + panic(err) + } + + if err = self.dest.SendPayload(payload, 0, PayloadTypeXg); err != nil { + panic(err) + } + //fmt.Printf("transmitted payload %d from %s -> %s\n", payload.Sequence, x.address, self.dest.address) +} + +func (self *testIntermediary) RetransmitPayload(srcAddr Address, payload *Payload) error { + //self.ForwardPayload(payload, nil) + return nil +} + +func (self *testIntermediary) validateMessage(m *channel.Message) error { + circuitId, found := m.GetStringHeader(HeaderKeyCircuitId) + if !found { + return errors.New("no circuit id found") + } + + if circuitId != self.circuitId { + return fmt.Errorf("expected circuit id %s, got %s", self.circuitId, circuitId) + } + + seq, found := m.GetUint64Header(HeaderKeySequence) + if !found { + return errors.New("no sequence found") + } + if seq != self.counter { + return fmt.Errorf("expected sequence %d, got %d", self.counter, seq) + } + self.counter++ + + return nil +} + +func (self *testIntermediary) ForwardControlMessage(control *Control, x *Xgress) { + panic("implement me") +} + +type testAcker struct { + destinations cmap.ConcurrentMap[string, *Xgress] +} + +func (self *testAcker) SendAck(ack *Acknowledgement, address Address) { + dest, _ := self.destinations.Get(string(address)) + if dest != nil { + if err := dest.SendAcknowledgement(ack); err != nil { + panic(err) + } + } else { + panic(fmt.Errorf("no xgress found with id %s", string(address))) + } +} + +type mockFaulter struct{} + +func (m mockFaulter) ReportForwardingFault(circuitId string, ctrlId string) { +} + +func Test_MinimalPayloadMarshalling(t *testing.T) { + logOptions := pfxlog.DefaultOptions().SetTrimPrefix("github.com/openziti/").NoColor() + pfxlog.GlobalInit(logrus.InfoLevel, logOptions) + pfxlog.SetFormatter(pfxlog.NewFormatter(pfxlog.DefaultOptions().SetTrimPrefix("github.com/openziti/").StartingToday())) + + metricsRegistry := metrics.NewRegistry("test", nil) + + closeNotify := make(chan struct{}) + defer func() { + close(closeNotify) + }() + + payloadIngester := NewPayloadIngester(closeNotify) + rtx := NewRetransmitter(mockFaulter{}, metricsRegistry, closeNotify) + ackHandler := &testAcker{destinations: cmap.New[*Xgress]()} + + options := DefaultOptions() + options.Mtu = 1400 + + circuitId := "circuit1" + srcTestConn := newTestXgConn(10_000, 100_000, 0) + dstTestConn := newTestXgConn(10_000, 0, 100_000) + + srcXg := NewXgress(circuitId, "ctrl", "src", srcTestConn, Initiator, options, nil) + dstXg := NewXgress(circuitId, "ctrl", "dst", dstTestConn, Terminator, options, nil) + + ackHandler.destinations.Set("src", dstXg) + ackHandler.destinations.Set("dst", srcXg) + + msgStrategy := channel.DatagramMessageStrategy(UnmarshallPacketPayload) + srcXg.dataPlane = &testIntermediary{ + acker: ackHandler, + rtx: rtx, + payloadIngester: payloadIngester, + circuitId: circuitId, + dest: dstXg, + msgs: msgStrategy, + } + + dstXg.dataPlane = &testIntermediary{ + acker: ackHandler, + rtx: rtx, + payloadIngester: payloadIngester, + circuitId: circuitId, + dest: srcXg, + msgs: msgStrategy, + } + + srcXg.Start() + dstXg.Start() + + select { + case <-dstTestConn.done: + case err := <-dstTestConn.errs: + t.Fatal(err) + case <-time.After(time.Second): + t.Fatal("timeout") + } +} + +func Test_PayloadSize(t *testing.T) { + logOptions := pfxlog.DefaultOptions().SetTrimPrefix("github.com/openziti/").NoColor() + pfxlog.GlobalInit(logrus.InfoLevel, logOptions) + pfxlog.SetFormatter(pfxlog.NewFormatter(pfxlog.DefaultOptions().SetTrimPrefix("github.com/openziti/").StartingToday())) + + metricsRegistry := metrics.NewRegistry("test", nil) + + closeNotify := make(chan struct{}) + defer func() { + close(closeNotify) + }() + + payloadIngester := NewPayloadIngester(closeNotify) + rtx := NewRetransmitter(mockFaulter{}, metricsRegistry, closeNotify) + ackHandler := &testAcker{destinations: cmap.New[*Xgress]()} + + options := DefaultOptions() + //options.Mtu = 1435 + + h := metricsRegistry.Histogram("msg_size") + + circuitId := "circuit2" + srcTestConn := newTestXgConn(200, 100_000, 0) + dstTestConn := newTestXgConn(200, 0, 100_000) + + srcXg := NewXgress(circuitId, "ctrl", "src", srcTestConn, Initiator, options, nil) + dstXg := NewXgress(circuitId, "ctrl", "dst", dstTestConn, Terminator, options, nil) + + ackHandler.destinations.Set("src", dstXg) + ackHandler.destinations.Set("dst", srcXg) + + msgStrategy := channel.DatagramMessageStrategy(UnmarshallPacketPayload) + srcXg.dataPlane = &testIntermediary{ + acker: ackHandler, + rtx: rtx, + payloadIngester: payloadIngester, + circuitId: circuitId, + dest: dstXg, + msgs: msgStrategy, + bytesCallback: func(bytes []byte) { + h.Update(int64(len(bytes))) + }, + } + + dstXg.dataPlane = &testIntermediary{ + acker: ackHandler, + rtx: rtx, + payloadIngester: payloadIngester, + circuitId: circuitId, + dest: srcXg, + msgs: msgStrategy, + } + + srcXg.Start() + dstXg.Start() + + select { + case <-dstTestConn.done: + case err := <-dstTestConn.errs: + t.Fatal(err) + case <-time.After(time.Second): + t.Fatal("timeout") + } + + fmt.Printf("max msg size: %d\n", h.(metrics2.Histogram).Max()) +} diff --git a/xgress/options.go b/xgress/options.go new file mode 100644 index 00000000..93a1df49 --- /dev/null +++ b/xgress/options.go @@ -0,0 +1,177 @@ +/* + Copyright NetFoundry Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package xgress + +import ( + "encoding/json" + "github.com/pkg/errors" + "time" +) + +// Options contains common Xgress configuration options +type Options struct { + Mtu int32 + RandomDrops bool + Drop1InN int32 + TxQueueSize int32 + + TxPortalStartSize uint32 + TxPortalMaxSize uint32 + TxPortalMinSize uint32 + TxPortalIncreaseThresh uint32 + TxPortalIncreaseScale float64 + TxPortalRetxThresh uint32 + TxPortalRetxScale float64 + TxPortalDupAckThresh uint32 + TxPortalDupAckScale float64 + + RxBufferSize uint32 + RetxStartMs uint32 + RetxScale float64 + RetxAddMs uint32 + + MaxCloseWait time.Duration + GetCircuitTimeout time.Duration + CircuitStartTimeout time.Duration + ConnectTimeout time.Duration +} + +func LoadOptions(data OptionsData) (*Options, error) { + options := DefaultOptions() + + if value, found := data["options"]; found { + data = value.(map[interface{}]interface{}) + + if value, found := data["mtu"]; found { + options.Mtu = int32(value.(int)) + } + if value, found := data["randomDrops"]; found { + options.RandomDrops = value.(bool) + } + if value, found := data["drop1InN"]; found { + options.Drop1InN = int32(value.(int)) + } + if value, found := data["txQueueSize"]; found { + options.TxQueueSize = int32(value.(int)) + } + + if value, found := data["txPortalStartSize"]; found { + options.TxPortalStartSize = uint32(value.(int)) + } + if value, found := data["txPortalMinSize"]; found { + options.TxPortalMinSize = uint32(value.(int)) + } + if value, found := data["txPortalMaxSize"]; found { + options.TxPortalMaxSize = uint32(value.(int)) + } + if value, found := data["txPortalIncreaseThresh"]; found { + options.TxPortalIncreaseThresh = uint32(value.(int)) + } + if value, found := data["txPortalIncreaseScale"]; found { + options.TxPortalIncreaseScale = value.(float64) + } + if value, found := data["txPortalRetxThresh"]; found { + options.TxPortalRetxThresh = uint32(value.(int)) + } + if value, found := data["txPortalRetxScale"]; found { + options.TxPortalRetxScale = value.(float64) + } + if value, found := data["txPortalDupAckThresh"]; found { + options.TxPortalDupAckThresh = uint32(value.(int)) + } + if value, found := data["txPortalDupAckScale"]; found { + options.TxPortalDupAckScale = value.(float64) + } + + if value, found := data["rxBufferSize"]; found { + options.RxBufferSize = uint32(value.(int)) + } + if value, found := data["retxStartMs"]; found { + options.RetxStartMs = uint32(value.(int)) + } + if value, found := data["retxScale"]; found { + options.RetxScale = value.(float64) + } + if value, found := data["retxAddMs"]; found { + options.RetxAddMs = uint32(value.(int)) + } + + if value, found := data["maxCloseWaitMs"]; found { + options.MaxCloseWait = time.Duration(value.(int)) * time.Millisecond + } + + if value, found := data["getCircuitTimeout"]; found { + getCircuitTimeout, err := time.ParseDuration(value.(string)) + if err != nil { + return nil, errors.Wrap(err, "invalid 'getCircuitTimeout' value") + } + options.GetCircuitTimeout = getCircuitTimeout + } + + if value, found := data["circuitStartTimeout"]; found { + circuitStartTimeout, err := time.ParseDuration(value.(string)) + if err != nil { + return nil, errors.Wrap(err, "invalid 'circuitStartTimeout' value") + } + options.CircuitStartTimeout = circuitStartTimeout + } + + if value, found := data["connectTimeout"]; found { + connectTimeout, err := time.ParseDuration(value.(string)) + if err != nil { + return nil, errors.Wrap(err, "invalid 'connectTimeout' value") + } + options.ConnectTimeout = connectTimeout + } + } + + return options, nil +} + +func DefaultOptions() *Options { + return &Options{ + Mtu: 0, + RandomDrops: false, + Drop1InN: 100, + TxQueueSize: 1, + TxPortalStartSize: 4 * 1024 * 1024, + TxPortalMinSize: 16 * 1024, + TxPortalMaxSize: 4 * 1024 * 1024, + TxPortalIncreaseThresh: 28, + TxPortalIncreaseScale: 1.0, + TxPortalRetxThresh: 64, + TxPortalRetxScale: 0.75, + TxPortalDupAckThresh: 64, + TxPortalDupAckScale: 0.9, + RxBufferSize: 4 * 1024 * 1024, + RetxStartMs: 200, + RetxScale: 1.5, + RetxAddMs: 0, + MaxCloseWait: 30 * time.Second, + GetCircuitTimeout: 30 * time.Second, + CircuitStartTimeout: 3 * time.Minute, + ConnectTimeout: 0, // operating system default + } +} + +func (options Options) String() string { + data, err := json.Marshal(options) + if err != nil { + return err.Error() + } + return string(data) +} diff --git a/xgress/ordering_test.go b/xgress/ordering_test.go new file mode 100644 index 00000000..42d732dd --- /dev/null +++ b/xgress/ordering_test.go @@ -0,0 +1,152 @@ +package xgress + +import ( + "encoding/binary" + "github.com/openziti/channel/v4" + "github.com/stretchr/testify/require" + "io" + "sync/atomic" + "testing" + "time" +) + +type testConn struct { + ch chan uint64 + closeNotify chan struct{} + closed atomic.Bool +} + +func (conn *testConn) Close() error { + if conn.closed.CompareAndSwap(false, true) { + close(conn.closeNotify) + } + return nil +} + +func (conn *testConn) LogContext() string { + return "test" +} + +func (conn *testConn) ReadPayload() ([]byte, map[uint8][]byte, error) { + <-conn.closeNotify + return nil, nil, io.EOF +} + +func (conn *testConn) WritePayload(bytes []byte, _ map[uint8][]byte) (int, error) { + val := binary.LittleEndian.Uint64(bytes) + conn.ch <- val + return len(bytes), nil +} + +func (conn *testConn) HandleControlMsg(ControlType, channel.Headers, ControlReceiver) error { + return nil +} + +type noopReceiveHandler struct { + payloadIngester *PayloadIngester +} + +func (n noopReceiveHandler) RetransmitPayload(srcAddr Address, payload *Payload) error { + return nil +} + +func (n noopReceiveHandler) GetMetrics() Metrics { + return noopMetrics{} +} + +func (n noopReceiveHandler) GetRetransmitter() *Retransmitter { + return nil +} + +func (n noopReceiveHandler) GetPayloadIngester() *PayloadIngester { + return n.payloadIngester +} + +func (n noopReceiveHandler) ForwardAcknowledgement(*Acknowledgement, Address) {} + +func (n noopReceiveHandler) ForwardPayload(*Payload, *Xgress) {} + +func (n noopReceiveHandler) ForwardControlMessage(*Control, *Xgress) {} + +func Test_Ordering(t *testing.T) { + closeNotify := make(chan struct{}) + + conn := &testConn{ + ch: make(chan uint64, 1), + closeNotify: make(chan struct{}), + } + + x := NewXgress("test", "ctrl", "test", conn, Initiator, DefaultOptions(), nil) + x.dataPlane = noopReceiveHandler{ + payloadIngester: NewPayloadIngester(closeNotify), + } + go x.tx() + + defer x.Close() + + msgCount := 100000 + + errorCh := make(chan error, 1) + + go func() { + for i := 0; i < msgCount; i++ { + data := make([]byte, 8) + binary.LittleEndian.PutUint64(data, uint64(i)) + payload := &Payload{ + CircuitId: "test", + Flags: SetOriginatorFlag(0, Terminator), + RTT: 0, + Sequence: int32(i), + Headers: nil, + Data: data, + } + if err := x.SendPayload(payload, 0, PayloadTypeXg); err != nil { + errorCh <- err + x.Close() + return + } + } + }() + + timeout := time.After(20 * time.Second) + + req := require.New(t) + for i := 0; i < msgCount; i++ { + select { + case next := <-conn.ch: + req.Equal(uint64(i), next) + case <-conn.closeNotify: + req.Fail("test failed with count at %v", i) + case err := <-errorCh: + req.NoError(err) + case <-timeout: + req.Failf("timed out", "count at %v", i) + } + } +} + +type noopMetrics struct{} + +func (n noopMetrics) MarkAckReceived() {} + +func (n noopMetrics) MarkPayloadDropped() {} + +func (n noopMetrics) MarkDuplicateAck() {} + +func (n noopMetrics) MarkDuplicatePayload() {} + +func (n noopMetrics) BufferBlockedByLocalWindow() {} + +func (n noopMetrics) BufferUnblockedByLocalWindow() {} + +func (n noopMetrics) BufferBlockedByRemoteWindow() {} + +func (n noopMetrics) BufferUnblockedByRemoteWindow() {} + +func (n noopMetrics) PayloadWritten(time.Duration) {} + +func (n noopMetrics) BufferUnblocked(time.Duration) {} + +func (n noopMetrics) SendPayloadBuffered(int64) {} + +func (n noopMetrics) SendPayloadDelivered(int64) {} diff --git a/xgress/payload_ingester.go b/xgress/payload_ingester.go new file mode 100644 index 00000000..cfc852fd --- /dev/null +++ b/xgress/payload_ingester.go @@ -0,0 +1,60 @@ +package xgress + +import "time" + +type payloadEntry struct { + payload *Payload + x *Xgress +} + +type PayloadIngester struct { + payloadIngest chan *payloadEntry + payloadSendReq chan *Xgress + receiveBufferInspects chan *receiveBufferInspectEvent + closeNotify <-chan struct{} +} + +func NewPayloadIngester(closeNotify <-chan struct{}) *PayloadIngester { + pi := &PayloadIngester{ + payloadIngest: make(chan *payloadEntry, 16), + payloadSendReq: make(chan *Xgress, 16), + receiveBufferInspects: make(chan *receiveBufferInspectEvent, 4), + closeNotify: closeNotify, + } + + go pi.run() + + return pi +} + +func (self *PayloadIngester) inspect(evt *receiveBufferInspectEvent, timeout <-chan time.Time) bool { + select { + case self.receiveBufferInspects <- evt: + return true + case <-self.closeNotify: + case <-timeout: + } + return false +} + +func (self *PayloadIngester) ingest(payload *Payload, x *Xgress) { + self.payloadIngest <- &payloadEntry{ + payload: payload, + x: x, + } +} + +func (self *PayloadIngester) run() { + for { + select { + case payloadEntry := <-self.payloadIngest: + payloadEntry.x.payloadIngester(payloadEntry.payload) + case x := <-self.payloadSendReq: + x.queueSends() + case evt := <-self.receiveBufferInspects: + evt.handle() + case <-self.closeNotify: + return + } + } +} diff --git a/xgress/retransmitter.go b/xgress/retransmitter.go new file mode 100644 index 00000000..fa308951 --- /dev/null +++ b/xgress/retransmitter.go @@ -0,0 +1,165 @@ +package xgress + +import ( + "github.com/michaelquigley/pfxlog" + "github.com/openziti/metrics" + "sync/atomic" +) + +type RetransmitterFaultReporter interface { + ReportForwardingFault(circuitId string, ctrlId string) +} + +type Retransmitter struct { + faultReporter RetransmitterFaultReporter + retxTail *txPayload + retxHead *txPayload + retransmitIngest chan *txPayload + retransmitSend chan *txPayload + retransmitsQueueSize int64 + closeNotify <-chan struct{} + + retransmissions metrics.Meter + retransmissionFailures metrics.Meter +} + +func NewRetransmitter(faultReporter RetransmitterFaultReporter, metrics metrics.Registry, closeNotify <-chan struct{}) *Retransmitter { + ctrl := &Retransmitter{ + retransmitIngest: make(chan *txPayload, 16), + retransmitSend: make(chan *txPayload, 1), + closeNotify: closeNotify, + faultReporter: faultReporter, + + retransmissions: metrics.Meter("xgress.retransmissions"), + retransmissionFailures: metrics.Meter("xgress.retransmission_failures"), + } + + go ctrl.retransmitIngester() + go ctrl.retransmitSender() + + metrics.FuncGauge("xgress.retransmits.queue_size", func() int64 { + return atomic.LoadInt64(&ctrl.retransmitsQueueSize) + }) + + return ctrl +} + +func (self *Retransmitter) queue(p *txPayload) { + self.retransmitIngest <- p +} + +func (self *Retransmitter) popHead() *txPayload { + if self.retxHead == nil { + return nil + } + + result := self.retxHead + if result.prev == nil { + self.retxHead = nil + self.retxTail = nil + } else { + self.retxHead = result.prev + result.prev.next = nil + } + + result.prev = nil + result.next = nil + + atomic.AddInt64(&self.retransmitsQueueSize, -1) + + return result +} + +func (self *Retransmitter) pushTail(txp *txPayload) { + if txp.prev != nil || txp.next != nil || txp == self.retxHead { + return + } + if self.retxHead == nil { + self.retxTail = txp + self.retxHead = txp + } else { + txp.next = self.retxTail + self.retxTail.prev = txp + self.retxTail = txp + } + atomic.AddInt64(&self.retransmitsQueueSize, 1) +} + +func (self *Retransmitter) delete(txp *txPayload) { + if self.retxHead == txp { + self.popHead() + } else if txp == self.retxTail { + self.retxTail = txp.next + self.retxTail.prev = nil + atomic.AddInt64(&self.retransmitsQueueSize, -1) + } else if txp.prev != nil { + txp.prev.next = txp.next + txp.next.prev = txp.prev + atomic.AddInt64(&self.retransmitsQueueSize, -1) + } + + txp.prev = nil + txp.next = nil +} + +func (self *Retransmitter) retransmitIngester() { + var next *txPayload + for { + if next == nil { + next = self.popHead() + } + + if next == nil { + select { + case retransmit := <-self.retransmitIngest: + self.acceptRetransmit(retransmit) + case <-self.closeNotify: + return + } + } else { + select { + case retransmit := <-self.retransmitIngest: + self.acceptRetransmit(retransmit) + case self.retransmitSend <- next: + next = nil + case <-self.closeNotify: + return + } + } + } +} + +func (self *Retransmitter) acceptRetransmit(txp *txPayload) { + if txp.isAcked() { + self.delete(txp) + } else { + self.pushTail(txp) + } +} + +func (self *Retransmitter) retransmitSender() { + logger := pfxlog.Logger() + for { + select { + case retransmit := <-self.retransmitSend: + if !retransmit.isAcked() { + if err := retransmit.x.dataPlane.RetransmitPayload(retransmit.x.address, retransmit.payload); err != nil { + // if xgress is closed, don't log the error. We still want to try retransmitting in case we're re-sending end of circuit + if !retransmit.x.Closed() { + logger.WithError(err).Errorf("unexpected error while retransmitting payload from [@/%v]", retransmit.x.address) + self.retransmissionFailures.Mark(1) + self.faultReporter.ReportForwardingFault(retransmit.payload.CircuitId, retransmit.x.ctrlId) + } else { + logger.WithError(err).Tracef("unexpected error while retransmitting payload from [@/%v] (already closed)", retransmit.x.address) + } + } else { + retransmit.markSent() + self.retransmissions.Mark(1) + } + retransmit.dequeued() + } + case <-self.closeNotify: + return + } + } +} diff --git a/xgress/xgress.go b/xgress/xgress.go new file mode 100644 index 00000000..050247ad --- /dev/null +++ b/xgress/xgress.go @@ -0,0 +1,1035 @@ +/* + Copyright NetFoundry Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package xgress + +import ( + "bufio" + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "math/rand" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/michaelquigley/pfxlog" + "github.com/openziti/channel/v4" + "github.com/openziti/foundation/v2/concurrenz" + "github.com/openziti/foundation/v2/debugz" + "github.com/openziti/foundation/v2/info" + "github.com/sirupsen/logrus" +) + +const ( + HeaderKeyUUID = 0 + + closedFlag = 0 + rxerStartedFlag = 1 + endOfCircuitRecvdFlag = 2 + endOfCircuitSentFlag = 3 +) + +type Address string + +type AckSender interface { + SendAck(ack *Acknowledgement, address Address) +} + +type OptionsData map[interface{}]interface{} + +// The BindHandlers are invoked to install the appropriate handlers. +type BindHandler interface { + HandleXgressBind(x *Xgress) +} + +type ControlReceiver interface { + HandleControlReceive(controlType ControlType, headers channel.Headers) +} + +// DataPlaneAdapter is invoked by an xgress whenever messages need to be sent to the data plane. Generally a DataPlaneAdapter +// is implemented to connect the xgress to a data plane data transmission system. +type DataPlaneAdapter interface { + // ForwardPayload is used to forward data payloads onto the data-plane from an xgress + ForwardPayload(payload *Payload, x *Xgress) + + // RetransmitPayload is used to retransmit data payloads onto the data-plane from an xgress + RetransmitPayload(srcAddr Address, payload *Payload) error + + // ForwardControlMessage is used to forward control messages onto the data-plane from an xgress + ForwardControlMessage(control *Control, x *Xgress) + + // ForwardAcknowledgement is used to forward acks onto the data-plane from an xgress + ForwardAcknowledgement(ack *Acknowledgement, address Address) + GetRetransmitter() *Retransmitter + GetPayloadIngester() *PayloadIngester + GetMetrics() Metrics +} + +// CloseHandler is invoked by an xgress when the connected peer terminates the communication. +type CloseHandler interface { + // HandleXgressClose is invoked when the connected peer terminates the communication. + // + HandleXgressClose(x *Xgress) +} + +// CloseHandlerF is the function version of CloseHandler +type CloseHandlerF func(x *Xgress) + +func (self CloseHandlerF) HandleXgressClose(x *Xgress) { + self(x) +} + +// PeekHandler allows registering watcher to react to data flowing an xgress instance +type PeekHandler interface { + Rx(x *Xgress, payload *Payload) + Tx(x *Xgress, payload *Payload) + Close(x *Xgress) +} + +type Connection interface { + io.Closer + LogContext() string + ReadPayload() ([]byte, map[uint8][]byte, error) + WritePayload([]byte, map[uint8][]byte) (int, error) + HandleControlMsg(controlType ControlType, headers channel.Headers, responder ControlReceiver) error +} + +type Xgress struct { + dataPlane DataPlaneAdapter + circuitId string + ctrlId string + address Address + peer Connection + originator Originator + Options *Options + txQueue chan *Payload + closeNotify chan struct{} + rxSequence uint64 + rxSequenceLock sync.Mutex + payloadBuffer *LinkSendBuffer + linkRxBuffer *LinkReceiveBuffer + closeHandlers []CloseHandler + peekHandlers []PeekHandler + flags concurrenz.AtomicBitSet + timeOfLastRxFromLink int64 + tags map[string]string +} + +func (self *Xgress) GetIntervalId() string { + return self.circuitId +} + +func (self *Xgress) GetTags() map[string]string { + return self.tags +} + +func NewXgress(circuitId string, ctrlId string, address Address, peer Connection, originator Originator, options *Options, tags map[string]string) *Xgress { + result := &Xgress{ + circuitId: circuitId, + ctrlId: ctrlId, + address: address, + peer: peer, + originator: originator, + Options: options, + txQueue: make(chan *Payload, options.TxQueueSize), + closeNotify: make(chan struct{}), + rxSequence: 0, + linkRxBuffer: NewLinkReceiveBuffer(), + timeOfLastRxFromLink: time.Now().UnixMilli(), + tags: tags, + } + result.payloadBuffer = NewLinkSendBuffer(result) + return result +} + +func (self *Xgress) GetTimeOfLastRxFromLink() int64 { + return atomic.LoadInt64(&self.timeOfLastRxFromLink) +} + +func (self *Xgress) CircuitId() string { + return self.circuitId +} + +func (self *Xgress) CtrlId() string { + return self.ctrlId +} + +func (self *Xgress) Address() Address { + return self.address +} + +func (self *Xgress) Originator() Originator { + return self.originator +} + +func (self *Xgress) IsTerminator() bool { + return self.originator == Terminator +} + +func (self *Xgress) SetDataPlaneAdapter(dataPlaneAdapter DataPlaneAdapter) { + self.dataPlane = dataPlaneAdapter +} + +func (self *Xgress) AddCloseHandler(closeHandler CloseHandler) { + self.closeHandlers = append(self.closeHandlers, closeHandler) +} + +func (self *Xgress) AddPeekHandler(peekHandler PeekHandler) { + self.peekHandlers = append(self.peekHandlers, peekHandler) +} + +func (self *Xgress) IsEndOfCircuitReceived() bool { + return self.flags.IsSet(endOfCircuitRecvdFlag) +} + +func (self *Xgress) markCircuitEndReceived() { + self.flags.Set(endOfCircuitRecvdFlag, true) +} + +func (self *Xgress) IsCircuitStarted() bool { + return !self.IsTerminator() || self.flags.IsSet(rxerStartedFlag) +} + +func (self *Xgress) firstCircuitStartReceived() bool { + return self.flags.CompareAndSet(rxerStartedFlag, false, true) +} + +func (self *Xgress) Start() { + log := pfxlog.ContextLogger(self.Label()) + if self.IsTerminator() { + log.Debug("terminator: waiting for circuit start before starting receiver") + if self.Options.CircuitStartTimeout > time.Second { + time.AfterFunc(self.Options.CircuitStartTimeout, self.terminateIfNotStarted) + } + } else { + log.Debug("initiator: sending circuit start") + self.forwardPayload(self.GetStartCircuit()) + go self.rx() + } + go self.tx() +} + +func (self *Xgress) terminateIfNotStarted() { + if !self.IsCircuitStarted() { + logrus.WithField("xgress", self.Label()).Warn("xgress circuit not started in time, closing") + self.Close() + } +} + +func (self *Xgress) Label() string { + return fmt.Sprintf("{c/%s|@/%s}<%s>", self.circuitId, string(self.address), self.originator.String()) +} + +func (self *Xgress) GetStartCircuit() *Payload { + startCircuit := &Payload{ + CircuitId: self.circuitId, + Flags: SetOriginatorFlag(uint32(PayloadFlagCircuitStart), self.originator), + Sequence: int32(self.nextReceiveSequence()), + Data: nil, + } + return startCircuit +} + +func (self *Xgress) GetEndCircuit() *Payload { + endCircuit := &Payload{ + CircuitId: self.circuitId, + Flags: SetOriginatorFlag(uint32(PayloadFlagCircuitEnd), self.originator), + Sequence: int32(self.nextReceiveSequence()), + Data: nil, + } + return endCircuit +} + +func (self *Xgress) ForwardEndOfCircuit(sendF func(payload *Payload) bool) { + // for now always send end of circuit. too many is better than not enough + if !self.IsEndOfCircuitSent() { + sendF(self.GetEndCircuit()) + self.flags.Set(endOfCircuitSentFlag, true) + } +} + +func (self *Xgress) IsEndOfCircuitSent() bool { + return self.flags.IsSet(endOfCircuitSentFlag) +} + +func (self *Xgress) CloseTimeout(duration time.Duration) { + if self.payloadBuffer.CloseWhenEmpty() { // If we clear the send buffer, close sooner + time.AfterFunc(duration, self.Close) + } +} + +func (self *Xgress) Unrouted() { + // When we're unrouted, if end of circuit hasn't already arrived, give incoming/queued data + // a chance to outflow before closing + if !self.flags.IsSet(closedFlag) { + self.payloadBuffer.Close() + time.AfterFunc(self.Options.MaxCloseWait, self.Close) + } +} + +/* +Things which can trigger close + +1. Read fails +2. Write fails +3. End of Circuit received +4. Unroute received +*/ +func (self *Xgress) Close() { + log := pfxlog.ContextLogger(self.Label()) + + if self.flags.CompareAndSet(closedFlag, false, true) { + log.Debug("closing xgress peer") + if err := self.peer.Close(); err != nil { + log.WithError(err).Warn("error while closing xgress peer") + } + + log.Debug("closing tx queue") + close(self.closeNotify) + + self.payloadBuffer.Close() + + for _, peekHandler := range self.peekHandlers { + peekHandler.Close(self) + } + + if len(self.closeHandlers) != 0 { + for _, closeHandler := range self.closeHandlers { + closeHandler.HandleXgressClose(self) + } + } else { + pfxlog.ContextLogger(self.Label()).Warn("no close handler") + } + } +} + +func (self *Xgress) Closed() bool { + return self.flags.IsSet(closedFlag) +} + +func (self *Xgress) SendPayload(payload *Payload, _ time.Duration, _ PayloadType) error { + if self.Closed() { + return nil + } + + if payload.IsCircuitEndFlagSet() { + pfxlog.ContextLogger(self.Label()).Debug("received end of circuit Payload") + } + atomic.StoreInt64(&self.timeOfLastRxFromLink, time.Now().UnixMilli()) + self.dataPlane.GetPayloadIngester().ingest(payload, self) + + return nil +} + +func (self *Xgress) SendAcknowledgement(acknowledgement *Acknowledgement) error { + self.dataPlane.GetMetrics().MarkAckReceived() + self.payloadBuffer.ReceiveAcknowledgement(acknowledgement) + return nil +} + +func (self *Xgress) SendControl(control *Control) error { + return self.peer.HandleControlMsg(control.Type, control.Headers, self) +} + +func (self *Xgress) HandleControlReceive(controlType ControlType, headers channel.Headers) { + control := &Control{ + Type: controlType, + CircuitId: self.circuitId, + Headers: headers, + } + self.dataPlane.ForwardControlMessage(control, self) +} + +func (self *Xgress) payloadIngester(payload *Payload) { + if payload.IsCircuitStartFlagSet() && self.firstCircuitStartReceived() { + go self.rx() + } + + if !self.Options.RandomDrops || rand.Int31n(self.Options.Drop1InN) != 1 { + self.PayloadReceived(payload) + } + self.queueSends() +} + +func (self *Xgress) queueSends() { + payload := self.linkRxBuffer.PeekHead() + for payload != nil { + select { + case self.txQueue <- payload: + self.linkRxBuffer.Remove(payload) + payload = self.linkRxBuffer.PeekHead() + default: + payload = nil + } + } +} + +func (self *Xgress) nextPayload() *Payload { + select { + case payload := <-self.txQueue: + return payload + default: + } + + // nothing was available in the txQueue, request more, then wait on txQueue + self.dataPlane.GetPayloadIngester().payloadSendReq <- self + + select { + case payload := <-self.txQueue: + return payload + case <-self.closeNotify: + } + + // closed, check if there's anything pending in the queue + select { + case payload := <-self.txQueue: + return payload + default: + return nil + } +} + +func (self *Xgress) tx() { + log := pfxlog.ContextLogger(self.Label()) + + log.Debug("started") + defer log.Debug("exited") + defer func() { + if self.IsEndOfCircuitReceived() { + self.Close() + } else { + self.flushSendThenClose() + } + }() + + clearPayloadFromSendBuffer := func(payload *Payload) { + payloadSize := len(payload.Data) + size := atomic.AddUint32(&self.linkRxBuffer.size, ^uint32(payloadSize-1)) // subtraction for uint32 + + payloadLogger := log.WithFields(payload.GetLoggerFields()) + payloadLogger.Debugf("payload %v of size %v removed from rx buffer, new size: %v", payload.Sequence, payloadSize, size) + + lastBufferSizeSent := self.linkRxBuffer.getLastBufferSizeSent() + if lastBufferSizeSent > 10000 && (lastBufferSizeSent>>1) > size { + self.SendEmptyAck() + } + } + + sendPayload := func(payload *Payload) bool { + payloadLogger := log.WithFields(payload.GetLoggerFields()) + + if payload.IsCircuitEndFlagSet() { + self.markCircuitEndReceived() + payloadLogger.Debug("circuit end payload received, exiting") + return false + } + + payloadLogger.Debug("sending") + + for _, peekHandler := range self.peekHandlers { + peekHandler.Tx(self, payload) + } + + if !payload.IsCircuitStartFlagSet() { + start := time.Now() + n, err := self.peer.WritePayload(payload.Data, payload.Headers) + if err != nil { + payloadLogger.Warnf("write failed (%s), closing xgress", err) + self.Close() + return false + } else { + self.dataPlane.GetMetrics().PayloadWritten(time.Since(start)) + payloadLogger.Debugf("payload sent [%s]", info.ByteCount(int64(n))) + } + } + return true + } + + var payload *Payload + var payloadChunk *Payload + + payloadStarted := false + payloadComplete := false + var payloadSize uint64 + var payloadWriteOffset int + + for { + payloadChunk = self.nextPayload() + + if payloadChunk == nil { + log.Debug("nil payload received, exiting") + return + } + + if !isFlagSet(payloadChunk.GetFlags(), PayloadFlagChunk) { + if !sendPayload(payloadChunk) { + return + } + clearPayloadFromSendBuffer(payloadChunk) + continue + } + + var payloadReadOffset int + if !payloadStarted { + payloadSize, payloadReadOffset = binary.Uvarint(payloadChunk.Data) + + if len(payloadChunk.Data) == 0 || payloadSize+uint64(payloadReadOffset) == uint64(len(payloadChunk.Data)) { + payload = payloadChunk + payload.Data = payload.Data[payloadReadOffset:] + payloadComplete = true + } else { + payload = &Payload{ + CircuitId: payloadChunk.CircuitId, + Flags: payloadChunk.Flags, + RTT: payloadChunk.RTT, + Sequence: payloadChunk.Sequence, + Headers: payloadChunk.Headers, + Data: make([]byte, payloadSize), + } + } + payloadStarted = true + } + + if !payloadComplete { + chunkData := payloadChunk.Data[payloadReadOffset:] + copy(payload.Data[payloadWriteOffset:], chunkData) + payloadWriteOffset += len(chunkData) + payloadComplete = uint64(payloadWriteOffset) == payloadSize + } + + payloadLogger := log.WithFields(payload.GetLoggerFields()) + payloadLogger.Debugf("received payload chunk. seq: %d, first: %v, complete: %v, chunk size: %d, payload size: %d, writeOffset: %d", + payloadChunk.Sequence, len(payload.Data) == 0 || payloadReadOffset > 0, payloadComplete, len(payloadChunk.Data), payloadSize, payloadWriteOffset) + + if !payloadComplete { + clearPayloadFromSendBuffer(payloadChunk) + continue + } + + payloadStarted = false + payloadComplete = false + payloadWriteOffset = 0 + + if !sendPayload(payload) { + return + } + clearPayloadFromSendBuffer(payloadChunk) + } +} + +func (self *Xgress) flushSendThenClose() { + self.CloseTimeout(self.Options.MaxCloseWait) + self.ForwardEndOfCircuit(func(payload *Payload) bool { + if self.payloadBuffer.closed.Load() { + // Avoid spurious 'failed to forward payload' error if the buffer is already closed + return false + } + + pfxlog.ContextLogger(self.Label()).Info("sending end of circuit payload") + return self.forwardPayload(payload) + }) +} + +/** + Payload format + + Field 1: 1 byte - version and flags + Masks + * 00000000 - Always 0 to indicate type. The standard channel header 4 byte protocol indicator has a 1 in bit 0 of the first byte + * 00000110 - Version, v0-v3. Assumption is that if we ever get to v4, we can roll back to 0, since everything + should have upgraded past v0 by that point + * 00001000 - Terminator Flag - indicates the payload origin, initiator (0) or terminator (1) + * 00010000 - RTT Flag. Indicates if the payload contains an RTT. We don't need to send RTT on every payload. + * 00100000 - Chunk Flag. Indicates if this payload is chunked. + * 01000000 - Headers flag. Indicates this payload contains headers. + * 10000000 - Heartbeat Flag. Indicates the payload contains a heartbeat + + Field 2: 1 byte, Circuit id size + Masks + * 00001111 - Number of bytes in circuit id. Supports circuit ids which take up to 15 bytes. + Circuits ids are currently at 9 bytes. + * 11110000 - currently unused + + Field 3: RTT (optional) + - 2 bytes + + Field 4: CircuitId + - direct bytes representation of string encoded circuit id + + Field 5: Sequence number + - Encoded using binary.PutUvarint + + Field 6: Headers + - Presence indicated by headers flag in first field + length - encoded with binary.PutUvarint + for each key/value pair - + key - 1 byte + value length - encoded with binary.PutUvarint + value - byte array, directly appended + + + Field 7: Data + + Field 8: Heartbeat + - 8 bytes + - only included if there's extra room +*/ + +const ( + VersionMask byte = 0b00000110 + TerminatorFlagMask byte = 0b00001000 + RttFlagMask byte = 0b00010000 + ChunkFlagMask byte = 0b00100000 + HeadersFlagMask byte = 0b01000000 + HeartbeatFlagMask byte = 0b10000000 + + CircuitIdSizeMask byte = 0b00001111 + PayloadProtocolV1 byte = 1 + PayloadProtocolOffset byte = 1 +) + +func (self *Xgress) rx() { + log := pfxlog.ContextLogger(self.Label()) + + log.Debugf("started with peer: %v", self.peer.LogContext()) + defer log.Debug("exited") + + defer func() { + if r := recover(); r != nil { + log.Errorf("send on closed channel. error: (%v)", r) + return + } + }() + + defer self.flushSendThenClose() + + for { + buffer, headers, err := self.peer.ReadPayload() + log.Debugf("payload read: %d bytes read", len(buffer)) + n := len(buffer) + + // if we got an EOF, but also some data, ignore the EOF, next read we'll get 0, EOF + if err != nil && (n == 0 || err != io.EOF) { + if err == io.EOF { + log.Debugf("EOF, exiting xgress.rx loop") + } else { + log.Warnf("read failed (%s)", err) + } + + return + } + + if self.Closed() { + return + } + + if self.Options.Mtu == 0 { + if !self.sendUnchunkedBuffer(buffer, headers) { + return + } + continue + } + + first := true + chunked := false + for len(buffer) > 0 || (first && len(headers) > 0) { + seq := self.nextReceiveSequence() + + chunk := make([]byte, self.Options.Mtu) + + flagsHeader := VersionMask & (PayloadProtocolV1 << PayloadProtocolOffset) + var sizesHeader byte + if self.originator == Terminator { + flagsHeader |= TerminatorFlagMask + } + + written := 2 + rest := chunk[2:] + includeRtt := seq%5 == 0 + if includeRtt { + flagsHeader |= RttFlagMask + written += 2 + rest = rest[2:] + } + + size := copy(rest, self.circuitId) + sizesHeader |= CircuitIdSizeMask & uint8(size) + written += size + rest = rest[size:] + size = binary.PutUvarint(rest, seq) + rest = rest[size:] + written += size + + if first && len(headers) > 0 { + flagsHeader |= HeadersFlagMask + size, err = writeU8ToBytesMap(headers, rest) + if err != nil { + log.WithError(err).Error("payload encoding error, closing") + return + } + rest = rest[size:] + written += size + } + + data := rest + dataLen := 0 + if first && len(rest) < len(buffer) { + chunked = true + size = binary.PutUvarint(rest, uint64(n)) + dataLen += size + written += size + rest = rest[size:] + } + + if chunked { + flagsHeader |= ChunkFlagMask + } + + size = copy(rest, buffer) + written += size + dataLen += size + + buffer = buffer[size:] + + // check if there's room for a heartbeat + if written+8 <= len(chunk) { + flagsHeader |= HeartbeatFlagMask + written += 8 + } + + chunk[0] = flagsHeader + chunk[1] = sizesHeader + + payload := &Payload{ + CircuitId: self.circuitId, + Flags: SetOriginatorFlag(0, self.originator), + Sequence: int32(seq), + Data: data[:dataLen], + raw: chunk[:written], + } + + if chunked { + payload.Flags = setPayloadFlag(payload.Flags, PayloadFlagChunk) + } + + if first { + payload.Headers = headers + } + + log.Debugf("sending payload chunk. seq: %d, first: %v, chunk size: %d, payload size: %d, remainder: %d", payload.Sequence, first, len(payload.Data), n, len(buffer)) + first = false + + // if the payload buffer is closed, we can't forward any more data, so might as well exit the rx loop + // The txer will still have a chance to flush any already received data + if !self.forwardPayload(payload) { + return + } + + payloadLogger := log.WithFields(payload.GetLoggerFields()) + payloadLogger.Debugf("forwarded [%s]", info.ByteCount(int64(n))) + } + + logrus.Debugf("received payload for [%d] bytes", n) + } +} + +func (self *Xgress) sendUnchunkedBuffer(buf []byte, headers map[uint8][]byte) bool { + log := pfxlog.ContextLogger(self.Label()) + + payload := &Payload{ + CircuitId: self.circuitId, + Flags: SetOriginatorFlag(0, self.originator), + Sequence: int32(self.nextReceiveSequence()), + Data: buf, + Headers: headers, + } + + log.Debugf("sending unchunked payload. seq: %d, payload size: %d", payload.Sequence, len(payload.Data)) + + // if the payload buffer is closed, we can't forward any more data, so might as well exit the rx loop + // The txer will still have a chance to flush any already received data + if !self.forwardPayload(payload) { + return false + } + + payloadLogger := log.WithFields(payload.GetLoggerFields()) + payloadLogger.Debugf("forwarded [%s]", info.ByteCount(int64(len(buf)))) + return true +} + +func (self *Xgress) forwardPayload(payload *Payload) bool { + sendCallback, err := self.payloadBuffer.BufferPayload(payload) + + if err != nil { + pfxlog.ContextLogger(self.Label()).WithError(err).Error("failure to buffer payload") + return false + } + + for _, peekHandler := range self.peekHandlers { + peekHandler.Rx(self, payload) + } + + self.dataPlane.ForwardPayload(payload, self) + sendCallback() + return true +} + +func (self *Xgress) nextReceiveSequence() uint64 { + self.rxSequenceLock.Lock() + defer self.rxSequenceLock.Unlock() + + next := self.rxSequence + self.rxSequence++ + + return next +} + +func (self *Xgress) PayloadReceived(payload *Payload) { + log := pfxlog.ContextLogger(self.Label()).WithFields(payload.GetLoggerFields()) + log.Debug("payload received") + if self.originator == payload.GetOriginator() { + // a payload sent from this xgress has arrived back at this xgress, instead of the other end + log.Warn("ouroboros (circuit cycle) detected, dropping payload") + } else if self.linkRxBuffer.ReceiveUnordered(self, payload, self.Options.RxBufferSize) { + log.Debug("ready to acknowledge") + + ack := NewAcknowledgement(self.circuitId, self.originator) + ack.RecvBufferSize = self.linkRxBuffer.Size() + ack.Sequence = append(ack.Sequence, payload.Sequence) + ack.RTT = payload.RTT + + atomic.StoreUint32(&self.linkRxBuffer.lastBufferSizeSent, ack.RecvBufferSize) + self.dataPlane.ForwardAcknowledgement(ack, self.address) + } else { + log.Debug("dropped") + } +} + +func (self *Xgress) SendEmptyAck() { + pfxlog.ContextLogger(self.Label()).WithField("circuit", self.circuitId).Debug("sending empty ack") + ack := NewAcknowledgement(self.circuitId, self.originator) + ack.RecvBufferSize = self.linkRxBuffer.Size() + atomic.StoreUint32(&self.linkRxBuffer.lastBufferSizeSent, ack.RecvBufferSize) + self.dataPlane.ForwardAcknowledgement(ack, self.address) +} + +func (self *Xgress) GetSequence() uint64 { + self.rxSequenceLock.Lock() + defer self.rxSequenceLock.Unlock() + return uint64(self.rxSequence) +} + +func (self *Xgress) InspectCircuit(detail *CircuitInspectDetail) { + timeSinceLastRxFromLink := time.Duration(time.Now().UnixMilli()-atomic.LoadInt64(&self.timeOfLastRxFromLink)) * time.Millisecond + xgressDetail := &InspectDetail{ + Address: string(self.address), + Originator: self.originator.String(), + TimeSinceLastLinkRx: timeSinceLastRxFromLink.String(), + SendBufferDetail: self.payloadBuffer.Inspect(), + RecvBufferDetail: self.linkRxBuffer.Inspect(self), + XgressPointer: fmt.Sprintf("%p", self), + LinkSendBufferPointer: fmt.Sprintf("%p", self.payloadBuffer), + Sequence: self.GetSequence(), + Flags: strconv.FormatUint(uint64(self.flags.Load()), 2), + } + + if detail.IncludeGoroutines() { + xgressDetail.Goroutines = self.getRelatedGoroutines(xgressDetail.XgressPointer, xgressDetail.LinkSendBufferPointer) + } + + detail.AddXgressDetail(xgressDetail) +} + +func (self *Xgress) getRelatedGoroutines(contains ...string) []string { + reader := bytes.NewBufferString(debugz.GenerateStack()) + scanner := bufio.NewScanner(reader) + var result []string + var buf *bytes.Buffer + xgressRelated := false + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "goroutine") && strings.HasSuffix(line, ":") { + result = self.addGoroutineIfRelated(buf, xgressRelated, result, contains...) + buf = &bytes.Buffer{} + xgressRelated = false + } + + if buf != nil { + if strings.Contains(line, "xgress") { + xgressRelated = true + } + buf.WriteString(line) + buf.WriteByte('\n') + } + } + result = self.addGoroutineIfRelated(buf, xgressRelated, result, contains...) + if err := scanner.Err(); err != nil { + result = append(result, "goroutine parsing error: %v", err.Error()) + } + return result +} + +func (self *Xgress) addGoroutineIfRelated(buf *bytes.Buffer, xgressRelated bool, result []string, contains ...string) []string { + if !xgressRelated { + return result + } + if buf != nil { + gr := buf.String() + // ignore the current goroutine + if strings.Contains(gr, "GenerateStack") { + return result + } + + for _, s := range contains { + if strings.Contains(gr, s) { + result = append(result, gr) + break + } + } + } + return result +} + +func UnmarshallPacketPayload(buf []byte) (*channel.Message, error) { + flagsField := buf[0] + if flagsField&1 != 0 { + return channel.ReadV2(bytes.NewBuffer(buf)) + } + version := (flagsField & VersionMask) >> 1 + if version != PayloadProtocolV1 { + return nil, fmt.Errorf("unsupported version: %d", version) + } + sizeField := buf[1] + circuitIdSize := CircuitIdSizeMask & sizeField + rest := buf[2:] + + var rtt *uint16 + if flagsField&RttFlagMask != 0 { + b0 := rest[0] + b1 := rest[1] + rest = rest[2:] + val := uint16(b0) | (uint16(b1) << 8) + rtt = &val + } + + var heartbeat *uint64 + if flagsField&HeartbeatFlagMask != 0 { + val := binary.BigEndian.Uint64(rest[len(rest)-8:]) + heartbeat = &val + rest = rest[:len(rest)-8] + } + + circuitId := string(rest[:circuitIdSize]) + rest = rest[circuitIdSize:] + seq, read := binary.Uvarint(rest) + rest = rest[read:] + + var headers map[uint8][]byte + if flagsField&HeadersFlagMask != 0 { + var err error + headers, rest, err = readU8ToBytesMap(rest) + if err != nil { + return nil, err + } + } + + msg := channel.NewMessage(ContentTypePayloadType, rest) + addPayloadHeadersToMsg(msg, headers) + msg.PutStringHeader(HeaderKeyCircuitId, circuitId) + msg.PutUint64Header(HeaderKeySequence, seq) + if heartbeat != nil { + msg.PutUint64Header(channel.HeartbeatHeader, *heartbeat) + } + msg.Headers[HeaderPayloadRaw] = buf + + flags := uint32(0) + + if flagsField&ChunkFlagMask != 0 { + flags = setPayloadFlag(flags, PayloadFlagChunk) + } + + if flagsField&TerminatorFlagMask != 0 { + flags = setPayloadFlag(flags, PayloadFlagOriginator) + } + + if flags != 0 { + msg.PutUint32Header(HeaderKeyFlags, flags) + } + + if rtt != nil { + msg.PutUint16Header(HeaderKeyRTT, *rtt) + } + + return msg, nil +} + +func writeU8ToBytesMap(m map[uint8][]byte, buf []byte) (int, error) { + written := binary.PutUvarint(buf, uint64(len(m))) + buf = buf[written:] + for k, v := range m { + if len(buf) < 10 { + return 0, fmt.Errorf("header too large, no space for header keys, payload has only %d bytes left", len(buf)) + } + buf[0] = k + written++ + buf = buf[1:] + + fieldLen := binary.PutUvarint(buf, uint64(len(v))) + buf = buf[fieldLen:] + written += fieldLen + if len(buf) < len(v) { + return 0, fmt.Errorf("header too large, no space for header value of size %d, only %d bytes available", len(v), len(buf)) + } + + fieldLen = copy(buf, v) + buf = buf[fieldLen:] + written += fieldLen + } + + return written, nil +} + +func readU8ToBytesMap(buf []byte) (map[uint8][]byte, []byte, error) { + result := map[uint8][]byte{} + count, offset := binary.Uvarint(buf) + if offset < 1 { + return nil, nil, errors.New("error reading payload header map length") + } + buf = buf[offset:] + for i := range count { + if len(buf) < 2 { + return nil, nil, fmt.Errorf("payload header error, ran out of space reading header %d", i) + } + k := buf[0] + valSize, read := binary.Uvarint(buf[1:]) + if read < 1 { + return nil, nil, fmt.Errorf("payload header error, ran out of space reading header %d", i) + } + buf = buf[read+1:] + if len(buf) < int(valSize) { + return nil, nil, fmt.Errorf("payload header error, ran out of space reading header %d", i) + } + result[k] = buf[:valSize] + buf = buf[valSize:] + } + + return result, buf, nil +} diff --git a/ziti/sdkinfo/build_info.go b/ziti/sdkinfo/build_info.go index 9d515295..2f1a8806 100644 --- a/ziti/sdkinfo/build_info.go +++ b/ziti/sdkinfo/build_info.go @@ -20,5 +20,5 @@ package sdkinfo const ( - Version = "v1.0.2" + Version = "v1.0.3" )