diff --git a/src/NATS.Client.Core/Commands/CommandWriter.cs b/src/NATS.Client.Core/Commands/CommandWriter.cs index 818b7b76..31cc150d 100644 --- a/src/NATS.Client.Core/Commands/CommandWriter.cs +++ b/src/NATS.Client.Core/Commands/CommandWriter.cs @@ -1,229 +1,150 @@ +using System.Buffers; using System.IO.Pipelines; using System.Runtime.CompilerServices; using System.Threading.Channels; +using Microsoft.Extensions.Logging; using NATS.Client.Core.Internal; namespace NATS.Client.Core.Commands; /// -/// Used to track commands that have been enqueued to the PipeReader -/// -internal readonly record struct QueuedCommand(int Size, int Trim = 0); - -/// -/// Sets up a Pipe, and provides methods to write to the PipeWriter +/// Sets up a buffer (Pipe), and provides methods to write protocol messages to the buffer /// When methods complete, they have been queued for sending /// and further cancellation is not possible /// /// /// These methods are in the hot path, and have all been -/// optimized to eliminate allocations and make an initial attempt -/// to run synchronously without the async state machine +/// optimized to eliminate allocations and minimize copying /// internal sealed class CommandWriter : IAsyncDisposable { + // set to a reasonable socket write mem size + private const int MaxSendSize = 16384; + + private readonly ILogger _logger; + private readonly ObjectPool _pool; + private readonly int _arrayPoolInitialSize; + private readonly object _lock = new(); + private readonly CancellationTokenSource _cts; private readonly ConnectionStatsCounter _counter; private readonly TimeSpan _defaultCommandTimeout; private readonly Action _enqueuePing; - private readonly NatsOpts _opts; - private readonly PipeWriter _pipeWriter; private readonly ProtocolWriter _protocolWriter; - private readonly ChannelWriter _queuedCommandsWriter; - private readonly SemaphoreSlim _semLock; + private readonly HeaderWriter _headerWriter; + private readonly Channel _channelLock; + private readonly Channel _channelSize; + private readonly CancellationTimerPool _ctPool; + private readonly PipeReader _pipeReader; + private readonly PipeWriter _pipeWriter; + private ISocketConnection? _socketConnection; private Task? _flushTask; - private bool _disposed; + private Task? _readerLoopTask; + private CancellationTokenSource? _ctsReader; + private volatile bool _disposed; - public CommandWriter(NatsOpts opts, ConnectionStatsCounter counter, Action enqueuePing, TimeSpan? overrideCommandTimeout = default) + public CommandWriter(ObjectPool pool, NatsOpts opts, ConnectionStatsCounter counter, Action enqueuePing, TimeSpan? overrideCommandTimeout = default) { + _logger = opts.LoggerFactory.CreateLogger(); + _pool = pool; + + // Derive ArrayPool rent size from buffer size to + // avoid defining another option. + _arrayPoolInitialSize = opts.WriterBufferSize / 256; + _counter = counter; _defaultCommandTimeout = overrideCommandTimeout ?? opts.CommandTimeout; _enqueuePing = enqueuePing; - _opts = opts; + _protocolWriter = new ProtocolWriter(opts.SubjectEncoding); + _channelLock = Channel.CreateBounded(1); + _channelSize = Channel.CreateUnbounded(new UnboundedChannelOptions { SingleWriter = true, SingleReader = true }); + _headerWriter = new HeaderWriter(opts.HeaderEncoding); + _cts = new CancellationTokenSource(); + var pipe = new Pipe(new PipeOptions( pauseWriterThreshold: opts.WriterBufferSize, // flush will block after hitting - resumeWriterThreshold: opts.WriterBufferSize / 2, // will start flushing again after catching up - minimumSegmentSize: 16384, // segment that is part of an uninterrupted payload can be sent using socket.send + resumeWriterThreshold: opts.WriterBufferSize / 2, useSynchronizationContext: false)); - PipeReader = pipe.Reader; + _pipeReader = pipe.Reader; _pipeWriter = pipe.Writer; - _protocolWriter = new ProtocolWriter(_pipeWriter, opts.SubjectEncoding, opts.HeaderEncoding); - var channel = Channel.CreateUnbounded(new UnboundedChannelOptions { SingleWriter = true, SingleReader = true }); - _semLock = new SemaphoreSlim(1); - QueuedCommandsReader = channel.Reader; - _queuedCommandsWriter = channel.Writer; - } - - public PipeReader PipeReader { get; } - - public ChannelReader QueuedCommandsReader { get; } - - public Queue InFlightCommands { get; } = new(); - - public async ValueTask DisposeAsync() - { - await _semLock.WaitAsync().ConfigureAwait(false); - try - { - if (_disposed) - { - return; - } - _disposed = true; - _queuedCommandsWriter.Complete(); - await _pipeWriter.CompleteAsync().ConfigureAwait(false); - } - finally - { - _semLock.Release(); - } + // We need a new ObjectPool here because of the root token (_cts.Token). + // When the root token is cancelled as this object is disposed, cancellation + // objects in the pooled CancellationTimer should not be reused since the + // root token would already be cancelled which means CancellationTimer tokens + // would always be in a cancelled state. + _ctPool = new CancellationTimerPool(new ObjectPool(opts.ObjectPoolSize), _cts.Token); } - public NatsPipeliningWriteProtocolProcessor CreateNatsPipeliningWriteProtocolProcessor(ISocketConnection socketConnection) => new(socketConnection, this, _opts, _counter); - - public ValueTask ConnectAsync(ClientOpts connectOpts, CancellationToken cancellationToken) + public void Reset(ISocketConnection socketConnection) { -#pragma warning disable CA2016 -#pragma warning disable VSTHRD103 - if (!_semLock.Wait(0)) -#pragma warning restore VSTHRD103 -#pragma warning restore CA2016 - { - return ConnectStateMachineAsync(false, connectOpts, cancellationToken); - } - - if (_flushTask is { IsCompletedSuccessfully: false }) - { - return ConnectStateMachineAsync(true, connectOpts, cancellationToken); - } - - try + lock (_lock) { - if (_disposed) - { - throw new ObjectDisposedException(nameof(CommandWriter)); - } + _socketConnection = socketConnection; + _ctsReader = new CancellationTokenSource(); - var success = false; - try - { - _protocolWriter.WriteConnect(connectOpts); - success = true; - } - finally + _readerLoopTask = Task.Run(async () => { - EnqueueCommand(success); - } - } - finally - { - _semLock.Release(); + await ReaderLoopAsync(_logger, _socketConnection, _pipeReader, _channelSize, _ctsReader.Token).ConfigureAwait(false); + }); } - - return ValueTask.CompletedTask; } - public ValueTask PingAsync(PingCommand pingCommand, CancellationToken cancellationToken) + public async Task CancelReaderLoopAsync() { -#pragma warning disable CA2016 -#pragma warning disable VSTHRD103 - if (!_semLock.Wait(0)) -#pragma warning restore VSTHRD103 -#pragma warning restore CA2016 - { - return PingStateMachineAsync(false, pingCommand, cancellationToken); - } - - if (_flushTask is { IsCompletedSuccessfully: false }) + CancellationTokenSource? cts; + Task? readerTask; + lock (_lock) { - return PingStateMachineAsync(true, pingCommand, cancellationToken); + cts = _ctsReader; + readerTask = _readerLoopTask; } - try + if (cts != null) { - if (_disposed) - { - throw new ObjectDisposedException(nameof(CommandWriter)); - } - - var success = false; - try - { - _protocolWriter.WritePing(); - _enqueuePing(pingCommand); - success = true; - } - finally - { - EnqueueCommand(success); - } - } - finally - { - _semLock.Release(); +#if NET6_0 + cts.Cancel(); +#else + await cts.CancelAsync().ConfigureAwait(false); +#endif } - return ValueTask.CompletedTask; + if (readerTask != null) + await readerTask.WaitAsync(TimeSpan.FromSeconds(3), _cts.Token).ConfigureAwait(false); } - public ValueTask PongAsync(CancellationToken cancellationToken = default) + public async ValueTask DisposeAsync() { -#pragma warning disable CA2016 -#pragma warning disable VSTHRD103 - if (!_semLock.Wait(0)) -#pragma warning restore VSTHRD103 -#pragma warning restore CA2016 + if (_disposed) { - return PongStateMachineAsync(false, cancellationToken); + return; } - if (_flushTask is { IsCompletedSuccessfully: false }) - { - return PongStateMachineAsync(true, cancellationToken); - } + _disposed = true; - try - { - if (_disposed) - { - throw new ObjectDisposedException(nameof(CommandWriter)); - } +#if NET6_0 + _cts.Cancel(); +#else + await _cts.CancelAsync().ConfigureAwait(false); +#endif - var success = false; - try - { - _protocolWriter.WritePong(); - success = true; - } - finally - { - EnqueueCommand(success); - } - } - finally + _channelLock.Writer.TryComplete(); + _channelSize.Writer.TryComplete(); + await _pipeWriter.CompleteAsync().ConfigureAwait(false); + + Task? readerTask; + lock (_lock) { - _semLock.Release(); + readerTask = _readerLoopTask; } - return ValueTask.CompletedTask; + if (readerTask != null) + await readerTask.ConfigureAwait(false); } - public ValueTask PublishAsync(string subject, T? value, NatsHeaders? headers, string? replyTo, INatsSerialize serializer, CancellationToken cancellationToken) + public async ValueTask ConnectAsync(ClientOpts connectOpts, CancellationToken cancellationToken) { -#pragma warning disable CA2016 -#pragma warning disable VSTHRD103 - if (!_semLock.Wait(0)) -#pragma warning restore VSTHRD103 -#pragma warning restore CA2016 - { - return PublishStateMachineAsync(false, subject, value, headers, replyTo, serializer, cancellationToken); - } - - if (_flushTask is { IsCompletedSuccessfully: false }) - { - return PublishStateMachineAsync(true, subject, value, headers, replyTo, serializer, cancellationToken); - } - + var cancellationTimer = _ctPool.Start(_defaultCommandTimeout, cancellationToken); + await LockAsync(cancellationTimer.Token).ConfigureAwait(false); try { if (_disposed) @@ -231,84 +152,28 @@ public ValueTask PublishAsync(string subject, T? value, NatsHeaders? headers, throw new ObjectDisposedException(nameof(CommandWriter)); } - var trim = 0; - var success = false; - try - { - trim = _protocolWriter.WritePublish(subject, value, headers, replyTo, serializer); - success = true; - } - finally + if (_flushTask is { IsCompletedSuccessfully: false }) { - EnqueueCommand(success, trim: trim); + await _flushTask.WaitAsync(cancellationTimer.Token).ConfigureAwait(false); } - } - finally - { - _semLock.Release(); - } - return ValueTask.CompletedTask; - } - - public ValueTask SubscribeAsync(int sid, string subject, string? queueGroup, int? maxMsgs, CancellationToken cancellationToken) - { -#pragma warning disable CA2016 -#pragma warning disable VSTHRD103 - if (!_semLock.Wait(0)) -#pragma warning restore VSTHRD103 -#pragma warning restore CA2016 - { - return SubscribeStateMachineAsync(false, sid, subject, queueGroup, maxMsgs, cancellationToken); - } - - if (_flushTask is { IsCompletedSuccessfully: false }) - { - return SubscribeStateMachineAsync(true, sid, subject, queueGroup, maxMsgs, cancellationToken); - } - - try - { - if (_disposed) - { - throw new ObjectDisposedException(nameof(CommandWriter)); - } + _protocolWriter.WriteConnect(_pipeWriter, connectOpts); - var success = false; - try - { - _protocolWriter.WriteSubscribe(sid, subject, queueGroup, maxMsgs); - success = true; - } - finally - { - EnqueueCommand(success); - } + _channelSize.Writer.TryWrite((int)_pipeWriter.UnflushedBytes); + var flush = _pipeWriter.FlushAsync(CancellationToken.None); + _flushTask = flush.IsCompletedSuccessfully ? null : flush.AsTask(); } finally { - _semLock.Release(); + await UnLockAsync().ConfigureAwait(false); + cancellationTimer.TryReturn(); } - - return ValueTask.CompletedTask; } - public ValueTask UnsubscribeAsync(int sid, CancellationToken cancellationToken) + public async ValueTask PingAsync(PingCommand pingCommand, CancellationToken cancellationToken) { -#pragma warning disable CA2016 -#pragma warning disable VSTHRD103 - if (!_semLock.Wait(0)) -#pragma warning restore VSTHRD103 -#pragma warning restore CA2016 - { - return UnsubscribeStateMachineAsync(false, sid, cancellationToken); - } - - if (_flushTask is { IsCompletedSuccessfully: false }) - { - return UnsubscribeStateMachineAsync(true, sid, cancellationToken); - } - + var cancellationTimer = _ctPool.Start(_defaultCommandTimeout, cancellationToken); + await LockAsync(cancellationTimer.Token).ConfigureAwait(false); try { if (_disposed) @@ -316,55 +181,29 @@ public ValueTask UnsubscribeAsync(int sid, CancellationToken cancellationToken) throw new ObjectDisposedException(nameof(CommandWriter)); } - var success = false; - try - { - _protocolWriter.WriteUnsubscribe(sid, null); - success = true; - } - finally - { - EnqueueCommand(success); - } - } - finally - { - _semLock.Release(); - } - - return ValueTask.CompletedTask; - } - - // only used for internal testing - internal async Task TestStallFlushAsync(TimeSpan timeSpan) - { - await _semLock.WaitAsync().ConfigureAwait(false); - - try - { if (_flushTask is { IsCompletedSuccessfully: false }) { - await _flushTask.ConfigureAwait(false); + await _flushTask.WaitAsync(cancellationTimer.Token).ConfigureAwait(false); } - _flushTask = Task.Delay(timeSpan); + _enqueuePing(pingCommand); + _protocolWriter.WritePing(_pipeWriter); + + _channelSize.Writer.TryWrite((int)_pipeWriter.UnflushedBytes); + var flush = _pipeWriter.FlushAsync(CancellationToken.None); + _flushTask = flush.IsCompletedSuccessfully ? null : flush.AsTask(); } finally { - _semLock.Release(); + await UnLockAsync().ConfigureAwait(false); + cancellationTimer.TryReturn(); } } - private async ValueTask ConnectStateMachineAsync(bool lockHeld, ClientOpts connectOpts, CancellationToken cancellationToken) + public async ValueTask PongAsync(CancellationToken cancellationToken = default) { - if (!lockHeld) - { - if (!await _semLock.WaitAsync(_defaultCommandTimeout, cancellationToken).ConfigureAwait(false)) - { - throw new TimeoutException(); - } - } - + var cancellationTimer = _ctPool.Start(_defaultCommandTimeout, cancellationToken); + await LockAsync(cancellationTimer.Token).ConfigureAwait(false); try { if (_disposed) @@ -374,76 +213,45 @@ private async ValueTask ConnectStateMachineAsync(bool lockHeld, ClientOpts conne if (_flushTask is { IsCompletedSuccessfully: false }) { - await _flushTask.WaitAsync(_defaultCommandTimeout, cancellationToken).ConfigureAwait(false); + await _flushTask.WaitAsync(cancellationTimer.Token).ConfigureAwait(false); } - var success = false; - try - { - _protocolWriter.WriteConnect(connectOpts); - success = true; - } - finally - { - EnqueueCommand(success); - } + _protocolWriter.WritePong(_pipeWriter); + + _channelSize.Writer.TryWrite((int)_pipeWriter.UnflushedBytes); + var flush = _pipeWriter.FlushAsync(CancellationToken.None); + _flushTask = flush.IsCompletedSuccessfully ? null : flush.AsTask(); } finally { - _semLock.Release(); + await UnLockAsync().ConfigureAwait(false); + cancellationTimer.TryReturn(); } } - private async ValueTask PingStateMachineAsync(bool lockHeld, PingCommand pingCommand, CancellationToken cancellationToken) + public ValueTask PublishAsync(string subject, T? value, NatsHeaders? headers, string? replyTo, INatsSerialize serializer, CancellationToken cancellationToken) { - if (!lockHeld) + NatsPooledBufferWriter? headersBuffer = null; + if (headers != null) { - if (!await _semLock.WaitAsync(_defaultCommandTimeout, cancellationToken).ConfigureAwait(false)) - { - throw new TimeoutException(); - } + if (!_pool.TryRent(out headersBuffer)) + headersBuffer = new NatsPooledBufferWriter(_arrayPoolInitialSize); + _headerWriter.Write(headersBuffer, headers); } - try - { - if (_disposed) - { - throw new ObjectDisposedException(nameof(CommandWriter)); - } - - if (_flushTask is { IsCompletedSuccessfully: false }) - { - await _flushTask.WaitAsync(_defaultCommandTimeout, cancellationToken).ConfigureAwait(false); - } + NatsPooledBufferWriter payloadBuffer; + if (!_pool.TryRent(out payloadBuffer!)) + payloadBuffer = new NatsPooledBufferWriter(_arrayPoolInitialSize); + if (value != null) + serializer.Serialize(payloadBuffer, value); - var success = false; - try - { - _protocolWriter.WritePing(); - _enqueuePing(pingCommand); - success = true; - } - finally - { - EnqueueCommand(success); - } - } - finally - { - _semLock.Release(); - } + return PublishLockedAsync(subject, replyTo, payloadBuffer, headersBuffer, cancellationToken); } - private async ValueTask PongStateMachineAsync(bool lockHeld, CancellationToken cancellationToken) + public async ValueTask SubscribeAsync(int sid, string subject, string? queueGroup, int? maxMsgs, CancellationToken cancellationToken) { - if (!lockHeld) - { - if (!await _semLock.WaitAsync(_defaultCommandTimeout, cancellationToken).ConfigureAwait(false)) - { - throw new TimeoutException(); - } - } - + var cancellationTimer = _ctPool.Start(_defaultCommandTimeout, cancellationToken); + await LockAsync(cancellationTimer.Token).ConfigureAwait(false); try { if (_disposed) @@ -453,36 +261,26 @@ private async ValueTask PongStateMachineAsync(bool lockHeld, CancellationToken c if (_flushTask is { IsCompletedSuccessfully: false }) { - await _flushTask.WaitAsync(_defaultCommandTimeout, cancellationToken).ConfigureAwait(false); + await _flushTask.WaitAsync(cancellationTimer.Token).ConfigureAwait(false); } - var success = false; - try - { - _protocolWriter.WritePong(); - success = true; - } - finally - { - EnqueueCommand(success); - } + _protocolWriter.WriteSubscribe(_pipeWriter, sid, subject, queueGroup, maxMsgs); + + _channelSize.Writer.TryWrite((int)_pipeWriter.UnflushedBytes); + var flush = _pipeWriter.FlushAsync(CancellationToken.None); + _flushTask = flush.IsCompletedSuccessfully ? null : flush.AsTask(); } finally { - _semLock.Release(); + await UnLockAsync().ConfigureAwait(false); + cancellationTimer.TryReturn(); } } - private async ValueTask PublishStateMachineAsync(bool lockHeld, string subject, T? value, NatsHeaders? headers, string? replyTo, INatsSerialize serializer, CancellationToken cancellationToken) + public async ValueTask UnsubscribeAsync(int sid, int? maxMsgs, CancellationToken cancellationToken) { - if (!lockHeld) - { - if (!await _semLock.WaitAsync(_defaultCommandTimeout, cancellationToken).ConfigureAwait(false)) - { - throw new TimeoutException(); - } - } - + var cancellationTimer = _ctPool.Start(_defaultCommandTimeout, cancellationToken); + await LockAsync(cancellationTimer.Token).ConfigureAwait(false); try { if (_disposed) @@ -492,167 +290,219 @@ private async ValueTask PublishStateMachineAsync(bool lockHeld, string subjec if (_flushTask is { IsCompletedSuccessfully: false }) { - await _flushTask.WaitAsync(_defaultCommandTimeout, cancellationToken).ConfigureAwait(false); + await _flushTask.WaitAsync(cancellationTimer.Token).ConfigureAwait(false); } - var trim = 0; - var success = false; - try - { - trim = _protocolWriter.WritePublish(subject, value, headers, replyTo, serializer); - success = true; - } - finally - { - EnqueueCommand(success, trim: trim); - } + _protocolWriter.WriteUnsubscribe(_pipeWriter, sid, maxMsgs); + + _channelSize.Writer.TryWrite((int)_pipeWriter.UnflushedBytes); + var flush = _pipeWriter.FlushAsync(CancellationToken.None); + _flushTask = flush.IsCompletedSuccessfully ? null : flush.AsTask(); } finally { - _semLock.Release(); + await UnLockAsync().ConfigureAwait(false); + cancellationTimer.TryReturn(); } } - private async ValueTask SubscribeStateMachineAsync(bool lockHeld, int sid, string subject, string? queueGroup, int? maxMsgs, CancellationToken cancellationToken) - { - if (!lockHeld) - { - if (!await _semLock.WaitAsync(_defaultCommandTimeout, cancellationToken).ConfigureAwait(false)) - { - throw new TimeoutException(); - } - } + // only used for internal testing + internal bool TestStallFlush() => _channelLock.Writer.TryWrite(1); + private static async Task ReaderLoopAsync(ILogger logger, ISocketConnection connection, PipeReader pipeReader, Channel channelSize, CancellationToken cancellationToken) + { try { - if (_disposed) - { - throw new ObjectDisposedException(nameof(CommandWriter)); - } - - if (_flushTask is { IsCompletedSuccessfully: false }) - { - await _flushTask.WaitAsync(_defaultCommandTimeout, cancellationToken).ConfigureAwait(false); - } - - var success = false; - try - { - _protocolWriter.WriteSubscribe(sid, subject, queueGroup, maxMsgs); - success = true; - } - finally - { - EnqueueCommand(success); - } - } - finally - { - _semLock.Release(); + var examinedOffset = 0; + while (true) + { + var result = await pipeReader.ReadAsync(cancellationToken).ConfigureAwait(false); + + if (result.IsCanceled) + { + break; + } + + var buffer = result.Buffer; + var consumed = buffer.Start; + var examined = buffer.GetPosition(examinedOffset); + var readBuffer = buffer.Slice(examinedOffset); + + try + { + if (!buffer.IsEmpty && !readBuffer.IsEmpty) + { + var bufferLength = (int)readBuffer.Length; + + var bytes = ArrayPool.Shared.Rent(bufferLength); + readBuffer.CopyTo(bytes); + var memory = bytes.AsMemory(0, bufferLength); + + try + { + var totalSent = 0; + var totalSize = 0; + while (totalSent < bufferLength) + { + var sendMemory = memory; + if (sendMemory.Length > MaxSendSize) + { + // cap the send size, the OS can only handle so much in a send buffer at a time + // also if the send fails, we have to throw this many bytes away + sendMemory = memory[..MaxSendSize]; + } + + int sent; + Exception? sendEx = null; + try + { + sent = await connection.SendAsync(sendMemory).ConfigureAwait(false); + } + catch (Exception ex) + { + // we have no idea how many bytes were actually sent, so we have to assume they all were + // this could result in message loss, but is consistent with at-most once delivery + sendEx = ex; + sent = sendMemory.Length; + } + + totalSent += sent; + memory = memory[sent..]; + + while (totalSize < totalSent) + { + int peek; + while (!channelSize.Reader.TryPeek(out peek)) + { + // should never happen; channel sizes are written before flush is called + await channelSize.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false); + } + + // Don't just mark the message as complete if we have more data to send + if (totalSize + peek > totalSent) + { + break; + } + + int size; + while (!channelSize.Reader.TryRead(out size)) + { + // should never happen; channel sizes are written before flush is called (plus we just peeked) + await channelSize.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false); + } + + totalSize += size; + examinedOffset = 0; + } + + // make sure to mark the buffer only at message boundaries. + consumed = buffer.GetPosition(totalSize); + examined = buffer.GetPosition(totalSent); + examinedOffset += totalSent - totalSize; + + // throw if there was a send failure + if (sendEx != null) + { + throw sendEx; + } + } + } + finally + { + ArrayPool.Shared.Return(bytes); + } + } + } + finally + { + // Always examine to the end to potentially unblock writer + pipeReader.AdvanceTo(consumed, examined); + } + + if (result.IsCompleted) + { + break; + } + } + } + catch (OperationCanceledException) + { + // Expected during shutdown + } + catch (InvalidOperationException) + { + // We might still be using the previous pipe reader which might be completed already + } + catch (Exception e) + { + logger.LogError(NatsLogEvents.Buffer, e, "Unexpected error in send buffer reader loop"); } } - private async ValueTask UnsubscribeStateMachineAsync(bool lockHeld, int sid, CancellationToken cancellationToken) + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))] + private async ValueTask PublishLockedAsync(string subject, string? replyTo, NatsPooledBufferWriter payloadBuffer, NatsPooledBufferWriter? headersBuffer, CancellationToken cancellationToken) { - if (!lockHeld) - { - if (!await _semLock.WaitAsync(_defaultCommandTimeout, cancellationToken).ConfigureAwait(false)) - { - throw new TimeoutException(); - } - } - + var cancellationTimer = _ctPool.Start(_defaultCommandTimeout, cancellationToken); + await LockAsync(cancellationTimer.Token).ConfigureAwait(false); try { + var payload = payloadBuffer.WrittenMemory; + var headers = headersBuffer?.WrittenMemory; + if (_disposed) { throw new ObjectDisposedException(nameof(CommandWriter)); } - if (_flushTask is { IsCompletedSuccessfully: false }) - { - await _flushTask.WaitAsync(_defaultCommandTimeout, cancellationToken).ConfigureAwait(false); - } + _protocolWriter.WritePublish(_pipeWriter, subject, replyTo, headers, payload); - var success = false; - try + payloadBuffer.Reset(); + _pool.Return(payloadBuffer); + + if (headersBuffer != null) { - _protocolWriter.WriteUnsubscribe(sid, null); - success = true; + headersBuffer.Reset(); + _pool.Return(headersBuffer); } - finally + + var size = (int)_pipeWriter.UnflushedBytes; + _channelSize.Writer.TryWrite(size); + + var result = await _pipeWriter.FlushAsync(cancellationTimer.Token).ConfigureAwait(false); + if (result.IsCanceled) { - EnqueueCommand(success); + throw new OperationCanceledException(); } } finally { - _semLock.Release(); + await UnLockAsync().ConfigureAwait(false); + cancellationTimer.TryReturn(); } } - /// - /// Enqueues a command, and kicks off a flush - /// - /// - /// Whether the command was successful - /// If true, it will be sent on the wire - /// If false, it will be thrown out - /// - /// - /// Number of bytes to skip from beginning of message - /// when sending on the wire - /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void EnqueueCommand(bool success, int trim = 0) + private async ValueTask LockAsync(CancellationToken cancellationToken) { - var size = (int)_pipeWriter.UnflushedBytes; - if (size == 0) + Interlocked.Increment(ref _counter.PendingMessages); + try { - // no unflushed bytes means no command was produced - _flushTask = null; - return; + await _channelLock.Writer.WriteAsync(1, cancellationToken).ConfigureAwait(false); } - - if (success) + catch (TaskCanceledException) { - Interlocked.Add(ref _counter.PendingMessages, 1); + throw new OperationCanceledException(); + } + catch (ChannelClosedException) + { + throw new OperationCanceledException(); } - - _queuedCommandsWriter.TryWrite(new QueuedCommand(Size: size, Trim: success ? trim : size)); - var flush = _pipeWriter.FlushAsync(); - _flushTask = flush.IsCompletedSuccessfully ? null : flush.AsTask(); - } -} - -internal sealed class PriorityCommandWriter : IAsyncDisposable -{ - private readonly NatsPipeliningWriteProtocolProcessor _natsPipeliningWriteProtocolProcessor; - private int _disposed; - - public PriorityCommandWriter(ISocketConnection socketConnection, NatsOpts opts, ConnectionStatsCounter counter, Action enqueuePing) - { - CommandWriter = new CommandWriter(opts, counter, enqueuePing, overrideCommandTimeout: TimeSpan.MaxValue); - _natsPipeliningWriteProtocolProcessor = CommandWriter.CreateNatsPipeliningWriteProtocolProcessor(socketConnection); } - public CommandWriter CommandWriter { get; } - - public async ValueTask DisposeAsync() + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private ValueTask UnLockAsync() { - if (Interlocked.Increment(ref _disposed) == 1) - { - // disposing command writer marks pipe writer as complete - await CommandWriter.DisposeAsync().ConfigureAwait(false); - try - { - // write loop will complete once pipe reader completes - await _natsPipeliningWriteProtocolProcessor.WriteLoop.ConfigureAwait(false); - } - finally - { - await _natsPipeliningWriteProtocolProcessor.DisposeAsync().ConfigureAwait(false); - } - } + Interlocked.Decrement(ref _counter.PendingMessages); + return _channelLock.Reader.ReadAsync(_cts.Token); } } diff --git a/src/NATS.Client.Core/Commands/NatsPooledBufferWriter.cs b/src/NATS.Client.Core/Commands/NatsPooledBufferWriter.cs new file mode 100644 index 00000000..b62d1c61 --- /dev/null +++ b/src/NATS.Client.Core/Commands/NatsPooledBufferWriter.cs @@ -0,0 +1,208 @@ +using System.Buffers; +using System.Numerics; +using System.Runtime.CompilerServices; +using NATS.Client.Core.Internal; + +namespace NATS.Client.Core.Commands; + +// adapted from https://github.com/CommunityToolkit/dotnet/blob/v8.2.2/src/CommunityToolkit.HighPerformance/Buffers/ArrayPoolBufferWriter%7BT%7D.cs +internal sealed class NatsPooledBufferWriter : IBufferWriter, IObjectPoolNode> +{ + private const int DefaultInitialMinBufferSize = 256; + private const int DefaultInitialMaxBufferSize = 65536; + + private readonly ArrayPool _pool; + private readonly int _size; + private T[]? _array; + private int _index; + private NatsPooledBufferWriter? _next; + + public NatsPooledBufferWriter(int size) + { + if (size < DefaultInitialMinBufferSize) + { + size = DefaultInitialMinBufferSize; + } + + if (size > DefaultInitialMaxBufferSize) + { + size = DefaultInitialMaxBufferSize; + } + + _size = size; + _pool = ArrayPool.Shared; + _array = _pool.Rent(size); + _index = 0; + } + + public ref NatsPooledBufferWriter? NextNode => ref _next; + + /// + /// Gets the data written to the underlying buffer so far, as a . + /// + public ReadOnlyMemory WrittenMemory + { + get + { + var array = _array; + + if (array is null) + { + ThrowObjectDisposedException(); + } + + return array!.AsMemory(0, _index); + } + } + + /// + /// Gets the data written to the underlying buffer so far, as a . + /// + public ReadOnlySpan WrittenSpan + { + get + { + var array = _array; + + if (array is null) + { + ThrowObjectDisposedException(); + } + + return array!.AsSpan(0, _index); + } + } + + /// + /// Gets the amount of data written to the underlying buffer so far. + /// + public int WrittenCount + { + get => _index; + } + + /// + public void Advance(int count) + { + var array = _array; + + if (array is null) + { + ThrowObjectDisposedException(); + } + + if (count < 0) + { + ThrowArgumentOutOfRangeExceptionForNegativeCount(); + } + + if (_index > array!.Length - count) + { + ThrowArgumentExceptionForAdvancedTooFar(); + } + + _index += count; + } + + /// + public Memory GetMemory(int sizeHint = 0) + { + CheckBufferAndEnsureCapacity(sizeHint); + + return _array.AsMemory(_index); + } + + /// + public Span GetSpan(int sizeHint = 0) + { + CheckBufferAndEnsureCapacity(sizeHint); + + return _array.AsSpan(_index); + } + + public void Reset() + { + if (_array != null) + _pool.Return(_array); + _array = _pool.Rent(_size); + _index = 0; + } + + /// + public override string ToString() + { + // See comments in MemoryOwner about this + if (typeof(T) == typeof(char) && + _array is char[] chars) + { + return new(chars, 0, _index); + } + + // Same representation used in Span + return $"NatsPooledBufferWriter<{typeof(T)}>[{_index}]"; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowArgumentOutOfRangeExceptionForNegativeCount() => throw new ArgumentOutOfRangeException("count", "The count can't be a negative value."); + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowArgumentOutOfRangeExceptionForNegativeSizeHint() => throw new ArgumentOutOfRangeException("sizeHint", "The size hint can't be a negative value."); + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowArgumentExceptionForAdvancedTooFar() => throw new ArgumentException("The buffer writer has advanced too far."); + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowObjectDisposedException() => throw new ObjectDisposedException("The current buffer has already been disposed."); + + /// + /// Ensures that has enough free space to contain a given number of new items. + /// + /// The minimum number of items to ensure space for in . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void CheckBufferAndEnsureCapacity(int sizeHint) + { + var array = _array; + + if (array is null) + { + ThrowObjectDisposedException(); + } + + if (sizeHint < 0) + { + ThrowArgumentOutOfRangeExceptionForNegativeSizeHint(); + } + + if (sizeHint == 0) + { + sizeHint = 1; + } + + if (sizeHint > array!.Length - _index) + { + ResizeBuffer(sizeHint); + } + } + + /// + /// Resizes to ensure it can fit the specified number of new items. + /// + /// The minimum number of items to ensure space for in . + [MethodImpl(MethodImplOptions.NoInlining)] + private void ResizeBuffer(int sizeHint) + { + var minimumSize = (uint)_index + (uint)sizeHint; + + // The ArrayPool class has a maximum threshold of 1024 * 1024 for the maximum length of + // pooled arrays, and once this is exceeded it will just allocate a new array every time + // of exactly the requested size. In that case, we manually round up the requested size to + // the nearest power of two, to ensure that repeated consecutive writes when the array in + // use is bigger than that threshold don't end up causing a resize every single time. + if (minimumSize > 1024 * 1024) + { + minimumSize = BitOperations.RoundUpToPowerOf2(minimumSize); + } + + _pool.Resize(ref _array, (int)minimumSize); + } +} diff --git a/src/NATS.Client.Core/Commands/PriorityCommandWriter.cs b/src/NATS.Client.Core/Commands/PriorityCommandWriter.cs new file mode 100644 index 00000000..3020e84f --- /dev/null +++ b/src/NATS.Client.Core/Commands/PriorityCommandWriter.cs @@ -0,0 +1,25 @@ +using NATS.Client.Core.Internal; + +namespace NATS.Client.Core.Commands; + +internal sealed class PriorityCommandWriter : IAsyncDisposable +{ + private int _disposed; + + public PriorityCommandWriter(ObjectPool pool, ISocketConnection socketConnection, NatsOpts opts, ConnectionStatsCounter counter, Action enqueuePing) + { + CommandWriter = new CommandWriter(pool, opts, counter, enqueuePing, overrideCommandTimeout: Timeout.InfiniteTimeSpan); + CommandWriter.Reset(socketConnection); + } + + public CommandWriter CommandWriter { get; } + + public async ValueTask DisposeAsync() + { + if (Interlocked.Increment(ref _disposed) == 1) + { + // disposing command writer marks pipe writer as complete + await CommandWriter.DisposeAsync().ConfigureAwait(false); + } + } +} diff --git a/src/NATS.Client.Core/Commands/ProtocolWriter.cs b/src/NATS.Client.Core/Commands/ProtocolWriter.cs index fe7bd9e6..6d477c0d 100644 --- a/src/NATS.Client.Core/Commands/ProtocolWriter.cs +++ b/src/NATS.Client.Core/Commands/ProtocolWriter.cs @@ -1,6 +1,6 @@ +using System.Buffers; using System.Buffers.Binary; using System.Buffers.Text; -using System.IO.Pipelines; using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; @@ -36,47 +36,40 @@ internal sealed class ProtocolWriter private static readonly ulong PongNewLine = BinaryPrimitives.ReadUInt64LittleEndian("PONG\r\n "u8); private static readonly ulong UnsubSpace = BinaryPrimitives.ReadUInt64LittleEndian("UNSUB "u8); - private readonly PipeWriter _writer; - private readonly HeaderWriter _headerWriter; private readonly Encoding _subjectEncoding; - public ProtocolWriter(PipeWriter writer, Encoding subjectEncoding, Encoding headerEncoding) - { - _writer = writer; - _subjectEncoding = subjectEncoding; - _headerWriter = new HeaderWriter(writer, headerEncoding); - } + public ProtocolWriter(Encoding subjectEncoding) => _subjectEncoding = subjectEncoding; // https://docs.nats.io/reference/reference-protocols/nats-protocol#connect // CONNECT {["option_name":option_value],...} - public void WriteConnect(ClientOpts opts) + public void WriteConnect(IBufferWriter writer, ClientOpts opts) { - var span = _writer.GetSpan(UInt64Length); + var span = writer.GetSpan(UInt64Length); BinaryPrimitives.WriteUInt64LittleEndian(span, ConnectSpace); - _writer.Advance(ConnectSpaceLength); + writer.Advance(ConnectSpaceLength); - var jsonWriter = new Utf8JsonWriter(_writer); + var jsonWriter = new Utf8JsonWriter(writer); JsonSerializer.Serialize(jsonWriter, opts, JsonContext.Default.ClientOpts); - span = _writer.GetSpan(UInt16Length); + span = writer.GetSpan(UInt16Length); BinaryPrimitives.WriteUInt16LittleEndian(span, NewLine); - _writer.Advance(NewLineLength); + writer.Advance(NewLineLength); } // https://docs.nats.io/reference/reference-protocols/nats-protocol#ping-pong - public void WritePing() + public void WritePing(IBufferWriter writer) { - var span = _writer.GetSpan(UInt64Length); + var span = writer.GetSpan(UInt64Length); BinaryPrimitives.WriteUInt64LittleEndian(span, PingNewLine); - _writer.Advance(PingNewLineLength); + writer.Advance(PingNewLineLength); } // https://docs.nats.io/reference/reference-protocols/nats-protocol#ping-pong - public void WritePong() + public void WritePong(IBufferWriter writer) { - var span = _writer.GetSpan(UInt64Length); + var span = writer.GetSpan(UInt64Length); BinaryPrimitives.WriteUInt64LittleEndian(span, PongNewLine); - _writer.Advance(PongNewLineLength); + writer.Advance(PongNewLineLength); } // https://docs.nats.io/reference/reference-protocols/nats-protocol#pub @@ -84,124 +77,21 @@ public void WritePong() // or // https://docs.nats.io/reference/reference-protocols/nats-protocol#hpub // HPUB [reply-to] <#header bytes> <#total bytes>\r\n[headers]\r\n\r\n[payload]\r\n - // - // returns the number of bytes that should be skipped when writing to the wire - public int WritePublish(string subject, T? value, NatsHeaders? headers, string? replyTo, INatsSerialize serializer) + public void WritePublish(IBufferWriter writer, string subject, string? replyTo, ReadOnlyMemory? headers, ReadOnlyMemory payload) { - int ctrlLen; - if (headers == null) - { - // 'PUB ' + subject +' '+ payload len +'\r\n' - ctrlLen = PubSpaceLength + _subjectEncoding.GetByteCount(subject) + 1 + MaxIntStringLength + NewLineLength; - } - else - { - // 'HPUB ' + subject +' '+ header len +' '+ payload len +'\r\n' - ctrlLen = HpubSpaceLength + _subjectEncoding.GetByteCount(subject) + 1 + MaxIntStringLength + 1 + MaxIntStringLength + NewLineLength; - } - - if (replyTo != null) - { - // len += replyTo +' ' - ctrlLen += _subjectEncoding.GetByteCount(replyTo) + 1; - } - - var ctrlSpan = _writer.GetSpan(ctrlLen); - var span = ctrlSpan; - if (headers == null) - { - BinaryPrimitives.WriteUInt32LittleEndian(span, PubSpace); - span = span[PubSpaceLength..]; - } - else - { - BinaryPrimitives.WriteUInt64LittleEndian(span, HpubSpace); - span = span[HpubSpaceLength..]; - } - - var written = _subjectEncoding.GetBytes(subject, span); - span[written] = (byte)' '; - span = span[(written + 1)..]; - - if (replyTo != null) - { - written = _subjectEncoding.GetBytes(replyTo, span); - span[written] = (byte)' '; - span = span[(written + 1)..]; - } - - Span lenSpan; if (headers == null) { - // len = payload len - lenSpan = span[..MaxIntStringLength]; - span = span[lenSpan.Length..]; + WritePub(writer, subject, replyTo, payload); } else { - // len = header len +' '+ payload len - lenSpan = span[..(MaxIntStringLength + 1 + MaxIntStringLength)]; - span = span[lenSpan.Length..]; + WriteHpub(writer, subject, replyTo, headers.Value, payload); } - - BinaryPrimitives.WriteUInt16LittleEndian(span, NewLine); - _writer.Advance(ctrlLen); - - var headersLength = 0L; - var totalLength = 0L; - if (headers != null) - { - headersLength = _headerWriter.Write(headers); - totalLength += headersLength; - } - - // Consider null as empty payload. This way we are able to transmit null values as sentinels. - // Another point is serializer behaviour. For instance JSON serializer seems to serialize null - // as a string "null", others might throw exception. - if (value != null) - { - var initialCount = _writer.UnflushedBytes; - serializer.Serialize(_writer, value); - totalLength += _writer.UnflushedBytes - initialCount; - } - - span = _writer.GetSpan(UInt16Length); - BinaryPrimitives.WriteUInt16LittleEndian(span, NewLine); - _writer.Advance(NewLineLength); - - // write the length - var lenWritten = 0; - if (headers != null) - { - if (!Utf8Formatter.TryFormat(headersLength, lenSpan, out lenWritten)) - { - ThrowOnUtf8FormatFail(); - } - - lenSpan[lenWritten] = (byte)' '; - lenWritten += 1; - } - - if (!Utf8Formatter.TryFormat(totalLength, lenSpan[lenWritten..], out var tLen)) - { - ThrowOnUtf8FormatFail(); - } - - lenWritten += tLen; - var trim = lenSpan.Length - lenWritten; - if (trim > 0) - { - // shift right - ctrlSpan[..(ctrlLen - trim - NewLineLength)].CopyTo(ctrlSpan[trim..]); - ctrlSpan[..trim].Clear(); - } - - return trim; } // https://docs.nats.io/reference/reference-protocols/nats-protocol#sub // SUB [queue group] - public void WriteSubscribe(int sid, string subject, string? queueGroup, int? maxMsgs) + public void WriteSubscribe(IBufferWriter writer, int sid, string subject, string? queueGroup, int? maxMsgs) { // 'SUB ' + subject +' '+ sid +'\r\n' var ctrlLen = SubSpaceLength + _subjectEncoding.GetByteCount(subject) + 1 + MaxIntStringLength + NewLineLength; @@ -212,7 +102,7 @@ public void WriteSubscribe(int sid, string subject, string? queueGroup, int? max ctrlLen += _subjectEncoding.GetByteCount(queueGroup) + 1; } - var span = _writer.GetSpan(ctrlLen); + var span = writer.GetSpan(ctrlLen); BinaryPrimitives.WriteUInt32LittleEndian(span, SubSpace); var size = SubSpaceLength; span = span[SubSpaceLength..]; @@ -241,20 +131,20 @@ public void WriteSubscribe(int sid, string subject, string? queueGroup, int? max BinaryPrimitives.WriteUInt16LittleEndian(span, NewLine); size += NewLineLength; - _writer.Advance(size); + writer.Advance(size); // Immediately send UNSUB to minimize the risk of // receiving more messages than in case they are published // between our SUB and UNSUB calls. if (maxMsgs != null) { - WriteUnsubscribe(sid, maxMsgs); + WriteUnsubscribe(writer, sid, maxMsgs); } } // https://docs.nats.io/reference/reference-protocols/nats-protocol#unsub // UNSUB [max_msgs] - public void WriteUnsubscribe(int sid, int? maxMessages) + public void WriteUnsubscribe(IBufferWriter writer, int sid, int? maxMessages) { // 'UNSUB ' + sid +'\r\n' var ctrlLen = UnsubSpaceLength + MaxIntStringLength + NewLineLength; @@ -264,7 +154,7 @@ public void WriteUnsubscribe(int sid, int? maxMessages) ctrlLen += 1 + MaxIntStringLength; } - var span = _writer.GetSpan(ctrlLen); + var span = writer.GetSpan(ctrlLen); BinaryPrimitives.WriteUInt64LittleEndian(span, UnsubSpace); var size = UnsubSpaceLength; span = span[UnsubSpaceLength..]; @@ -291,10 +181,135 @@ public void WriteUnsubscribe(int sid, int? maxMessages) BinaryPrimitives.WriteUInt16LittleEndian(span, NewLine); size += NewLineLength; - _writer.Advance(size); + writer.Advance(size); } // optimization detailed here: https://github.com/nats-io/nats.net.v2/issues/320#issuecomment-1886165748 [MethodImpl(MethodImplOptions.NoInlining)] private static void ThrowOnUtf8FormatFail() => throw new NatsException("Can not format integer."); + + // PUB [reply-to] <#bytes>\r\n[payload]\r\n + private void WritePub(IBufferWriter writer, string subject, string? replyTo, ReadOnlyMemory payload) + { + Span spanPayloadLength = stackalloc byte[MaxIntStringLength]; + if (!Utf8Formatter.TryFormat(payload.Length, spanPayloadLength, out var payloadLengthWritten)) + { + ThrowOnUtf8FormatFail(); + } + + spanPayloadLength = spanPayloadLength.Slice(0, payloadLengthWritten); + + var total = PubSpaceLength; + + var subjectSpaceLength = _subjectEncoding.GetByteCount(subject) + 1; + total += subjectSpaceLength; + + var replyToLengthSpace = 0; + if (replyTo != null) + { + replyToLengthSpace = _subjectEncoding.GetByteCount(replyTo) + 1; + total += replyToLengthSpace; + } + + total += spanPayloadLength.Length + NewLineLength + payload.Length + NewLineLength; + + var span = writer.GetSpan(total); + + BinaryPrimitives.WriteUInt32LittleEndian(span, PubSpace); + span = span.Slice(PubSpaceLength); + + _subjectEncoding.GetBytes(subject, span); + span[subjectSpaceLength - 1] = (byte)' '; + span = span.Slice(subjectSpaceLength); + + if (replyTo != null) + { + _subjectEncoding.GetBytes(replyTo, span); + span[replyToLengthSpace - 1] = (byte)' '; + span = span.Slice(replyToLengthSpace); + } + + spanPayloadLength.CopyTo(span); + span = span.Slice(spanPayloadLength.Length); + + BinaryPrimitives.WriteUInt16LittleEndian(span, NewLine); + span = span.Slice(NewLineLength); + + payload.Span.CopyTo(span); + span = span.Slice(payload.Length); + + BinaryPrimitives.WriteUInt16LittleEndian(span, NewLine); + + writer.Advance(total); + } + + // HPUB [reply-to] <#header bytes> <#total bytes>\r\n[headers]\r\n\r\n[payload]\r\n + private void WriteHpub(IBufferWriter writer, string subject, string? replyTo, ReadOnlyMemory headers, ReadOnlyMemory payload) + { + Span spanPayloadLength = stackalloc byte[MaxIntStringLength]; + if (!Utf8Formatter.TryFormat(payload.Length + headers.Length, spanPayloadLength, out var payloadLengthWritten)) + { + ThrowOnUtf8FormatFail(); + } + + spanPayloadLength = spanPayloadLength.Slice(0, payloadLengthWritten); + + Span spanHeadersLength = stackalloc byte[MaxIntStringLength + 1]; + if (!Utf8Formatter.TryFormat(headers.Length, spanHeadersLength, out var headersLengthWritten)) + { + ThrowOnUtf8FormatFail(); + } + + spanHeadersLength = spanHeadersLength.Slice(0, headersLengthWritten + 1); + spanHeadersLength[headersLengthWritten] = (byte)' '; + + var total = HpubSpaceLength; + + var subjectSpaceLength = _subjectEncoding.GetByteCount(subject) + 1; + total += subjectSpaceLength; + + var replyToLengthSpace = 0; + if (replyTo != null) + { + replyToLengthSpace = _subjectEncoding.GetByteCount(replyTo) + 1; + total += replyToLengthSpace; + } + + total += spanHeadersLength.Length + spanPayloadLength.Length + NewLineLength + headers.Length + payload.Length + NewLineLength; + + var span = writer.GetSpan(total); + + BinaryPrimitives.WriteUInt64LittleEndian(span, HpubSpace); + span = span.Slice(HpubSpaceLength); + + _subjectEncoding.GetBytes(subject, span); + span[subjectSpaceLength - 1] = (byte)' '; + span = span.Slice(subjectSpaceLength); + + if (replyTo != null) + { + _subjectEncoding.GetBytes(replyTo, span); + span[replyToLengthSpace - 1] = (byte)' '; + span = span.Slice(replyToLengthSpace); + } + + spanHeadersLength.CopyTo(span); + span = span.Slice(spanHeadersLength.Length); + + spanPayloadLength.CopyTo(span); + span = span.Slice(spanPayloadLength.Length); + + BinaryPrimitives.WriteUInt16LittleEndian(span, NewLine); + span = span.Slice(NewLineLength); + + headers.Span.CopyTo(span); + span = span.Slice(headers.Length); + + payload.Span.CopyTo(span); + span = span.Slice(payload.Length); + + BinaryPrimitives.WriteUInt16LittleEndian(span, NewLine); + + writer.Advance(total); + } } diff --git a/src/NATS.Client.Core/Internal/CancellationTimer.cs b/src/NATS.Client.Core/Internal/CancellationTimer.cs index 537184ec..6fe7b280 100644 --- a/src/NATS.Client.Core/Internal/CancellationTimer.cs +++ b/src/NATS.Client.Core/Internal/CancellationTimer.cs @@ -66,7 +66,11 @@ public static CancellationTimer Start(ObjectPool pool, CancellationToken rootTok } self._timeout = timeout; - self._cancellationTokenSource.CancelAfter(timeout); + if (timeout != Timeout.InfiniteTimeSpan) + { + self._cancellationTokenSource.CancelAfter(timeout); + } + return self; } diff --git a/src/NATS.Client.Core/Internal/HeaderWriter.cs b/src/NATS.Client.Core/Internal/HeaderWriter.cs index 47b0d8c2..8bb62694 100644 --- a/src/NATS.Client.Core/Internal/HeaderWriter.cs +++ b/src/NATS.Client.Core/Internal/HeaderWriter.cs @@ -12,23 +12,19 @@ internal class HeaderWriter private const byte ByteColon = (byte)':'; private const byte ByteSpace = (byte)' '; private const byte ByteDel = 127; - private readonly PipeWriter _pipeWriter; + private readonly Encoding _encoding; - public HeaderWriter(PipeWriter pipeWriter, Encoding encoding) - { - _pipeWriter = pipeWriter; - _encoding = encoding; - } + public HeaderWriter(Encoding encoding) => _encoding = encoding; private static ReadOnlySpan CrLf => new[] { ByteCr, ByteLf }; private static ReadOnlySpan ColonSpace => new[] { ByteColon, ByteSpace }; - internal long Write(NatsHeaders headers) + internal long Write(IBufferWriter bufferWriter, NatsHeaders headers) { - var initialCount = _pipeWriter.UnflushedBytes; - _pipeWriter.WriteSpan(CommandConstants.NatsHeaders10NewLine); + bufferWriter.WriteSpan(CommandConstants.NatsHeaders10NewLine); + var len = CommandConstants.NatsHeaders10NewLine.Length; foreach (var kv in headers) { @@ -38,7 +34,7 @@ internal long Write(NatsHeaders headers) { // write key var keyLength = _encoding.GetByteCount(kv.Key); - var keySpan = _pipeWriter.GetSpan(keyLength); + var keySpan = bufferWriter.GetSpan(keyLength); _encoding.GetBytes(kv.Key, keySpan); if (!ValidateKey(keySpan.Slice(0, keyLength))) { @@ -46,20 +42,26 @@ internal long Write(NatsHeaders headers) $"Invalid header key '{kv.Key}': contains colon, space, or other non-printable ASCII characters"); } - _pipeWriter.Advance(keyLength); - _pipeWriter.Write(ColonSpace); + bufferWriter.Advance(keyLength); + len += keyLength; + + bufferWriter.Write(ColonSpace); + len += ColonSpace.Length; // write values var valueLength = _encoding.GetByteCount(value); - var valueSpan = _pipeWriter.GetSpan(valueLength); + var valueSpan = bufferWriter.GetSpan(valueLength); _encoding.GetBytes(value, valueSpan); if (!ValidateValue(valueSpan.Slice(0, valueLength))) { throw new NatsException($"Invalid header value for key '{kv.Key}': contains CRLF"); } - _pipeWriter.Advance(valueLength); - _pipeWriter.Write(CrLf); + bufferWriter.Advance(valueLength); + len += valueLength; + + bufferWriter.Write(CrLf); + len += CrLf.Length; } } } @@ -67,9 +69,10 @@ internal long Write(NatsHeaders headers) // Even empty header needs to terminate. // We will send NATS/1.0 version line // even if there are no headers. - _pipeWriter.Write(CrLf); + bufferWriter.Write(CrLf); + len += CrLf.Length; - return _pipeWriter.UnflushedBytes - initialCount; + return len; } // cannot contain ASCII Bytes <=32, 58, or 127 diff --git a/src/NATS.Client.Core/Internal/NatsPipeliningWriteProtocolProcessor.cs b/src/NATS.Client.Core/Internal/NatsPipeliningWriteProtocolProcessor.cs index 50661ada..922cb5b9 100644 --- a/src/NATS.Client.Core/Internal/NatsPipeliningWriteProtocolProcessor.cs +++ b/src/NATS.Client.Core/Internal/NatsPipeliningWriteProtocolProcessor.cs @@ -1,320 +1,309 @@ -using System.Buffers; -using System.Diagnostics; -using System.IO.Pipelines; -using System.Threading.Channels; -using Microsoft.Extensions.Logging; -using NATS.Client.Core.Commands; - -namespace NATS.Client.Core.Internal; - -internal sealed class NatsPipeliningWriteProtocolProcessor : IAsyncDisposable -{ - private readonly CancellationTokenSource _cts; - private readonly ConnectionStatsCounter _counter; - private readonly NatsOpts _opts; - private readonly PipeReader _pipeReader; - private readonly Queue _inFlightCommands; - private readonly ChannelReader _queuedCommandReader; - private readonly ISocketConnection _socketConnection; - private readonly Stopwatch _stopwatch = new Stopwatch(); - private int _disposed; - - public NatsPipeliningWriteProtocolProcessor(ISocketConnection socketConnection, CommandWriter commandWriter, NatsOpts opts, ConnectionStatsCounter counter) - { - _cts = new CancellationTokenSource(); - _counter = counter; - _inFlightCommands = commandWriter.InFlightCommands; - _opts = opts; - _pipeReader = commandWriter.PipeReader; - _queuedCommandReader = commandWriter.QueuedCommandsReader; - _socketConnection = socketConnection; - WriteLoop = Task.Run(WriteLoopAsync); - } - - public Task WriteLoop { get; } - - public async ValueTask DisposeAsync() - { - if (Interlocked.Increment(ref _disposed) == 1) - { -#if NET6_0 - _cts.Cancel(); -#else - await _cts.CancelAsync().ConfigureAwait(false); -#endif - try - { - await WriteLoop.ConfigureAwait(false); // wait to drain writer - } - catch - { - // ignore - } - } - } - - private async Task WriteLoopAsync() - { - var logger = _opts.LoggerFactory.CreateLogger(); - var isEnabledTraceLogging = logger.IsEnabled(LogLevel.Trace); - var cancellationToken = _cts.Token; - var pending = 0; - var trimming = 0; - var examinedOffset = 0; - var sent = 0; - - // memory segment used to consolidate multiple small memory chunks - // should <= (minimumSegmentSize * 0.5) in CommandWriter - // 8520 should fit into 6 packets on 1500 MTU TLS connection or 1 packet on 9000 MTU TLS connection - // assuming 40 bytes TCP overhead + 40 bytes TLS overhead per packet - var consolidateMem = new Memory(new byte[8520]); - - // add up in flight command sum - var inFlightSum = 0; - foreach (var command in _inFlightCommands) - { - inFlightSum += command.Size; - } - - try - { - while (true) - { - // read data from pipe reader - var result = await _pipeReader.ReadAsync(cancellationToken).ConfigureAwait(false); - if (result.IsCanceled) - { - // if the pipe has been canceled, break - break; - } - - var consumedPos = result.Buffer.Start; - var examinedPos = result.Buffer.Start; - try - { - // move from _queuedCommandReader to _inFlightCommands until the total size - // of all _inFlightCommands is >= result.Buffer.Length - while (inFlightSum < result.Buffer.Length) - { - QueuedCommand queuedCommand; - while (!_queuedCommandReader.TryRead(out queuedCommand)) - { - await _queuedCommandReader.WaitToReadAsync(cancellationToken).ConfigureAwait(false); - } - - _inFlightCommands.Enqueue(queuedCommand); - inFlightSum += queuedCommand.Size; - } - - // examinedOffset was processed last iteration, so slice it off the buffer - var buffer = result.Buffer.Slice(examinedOffset); - - // iterate until buffer is empty - // any time buffer sliced and re-assigned, continue should be called - // so that this conditional is checked - while (buffer.Length > 0) - { - // if there are no pending bytes to send, set to next command - if (pending == 0) - { - var peek = _inFlightCommands.Peek(); - pending = peek.Size; - trimming = peek.Trim; - } - - // from this point forward, pending != 0 - // any operation that decrements pending should check if it is 0, - // and dequeue from _inFlightCommands if it is - - // trim any bytes that should not be sent - if (trimming > 0) - { - var trimmed = Math.Min(trimming, (int)buffer.Length); - pending -= trimmed; - trimming -= trimmed; - if (pending == 0) - { - // the entire command was trimmed (canceled) - inFlightSum -= _inFlightCommands.Dequeue().Size; - consumedPos = buffer.GetPosition(trimmed); - examinedPos = consumedPos; - examinedOffset = 0; - buffer = buffer.Slice(trimmed); - - // iterate in case buffer is now empty - continue; - } - - // the command was partially trimmed - examinedPos = buffer.GetPosition(trimmed); - examinedOffset += trimmed; - buffer = buffer.Slice(trimmed); - - // iterate in case buffer is now empty - continue; - } - - if (sent > 0) - { - if (pending <= sent) - { - // the entire command was sent - inFlightSum -= _inFlightCommands.Dequeue().Size; - Interlocked.Add(ref _counter.PendingMessages, -1); - Interlocked.Add(ref _counter.SentMessages, 1); - - // mark the bytes as consumed, and reset pending - sent -= pending; - consumedPos = buffer.GetPosition(pending); - examinedPos = consumedPos; - examinedOffset = 0; - buffer = buffer.Slice(pending); - pending = 0; - - // iterate in case buffer is now empty - continue; - } - - // the command was partially sent - // decrement pending by the number of bytes that were sent - pending -= sent; - examinedPos = buffer.GetPosition(sent); - examinedOffset += sent; - buffer = buffer.Slice(sent); - sent = 0; - - // iterate in case buffer is now empty - continue; - } - - // loop through _inFlightCommands to determine whether any commands - // in the first memory segment need trimming - var sendMem = buffer.First; - var maxSize = 0; - var maxSizeCap = Math.Max(sendMem.Length, consolidateMem.Length); - var doTrim = false; - foreach (var command in _inFlightCommands) - { - if (maxSize == 0) - { - // first command; set to pending - maxSize = pending; - continue; - } - - if (maxSize > maxSizeCap) - { - // over cap - break; - } - - if (command.Trim > 0) - { - // will have to trim - doTrim = true; - break; - } - - maxSize += command.Size; - } - - // adjust the first memory segment to end on a command boundary - if (sendMem.Length > maxSize) - { - sendMem = sendMem[..maxSize]; - } - - // if trimming is required or the first memory segment is smaller than consolidateMem - // consolidate bytes that need to be sent into consolidateMem - if (doTrim || (buffer.Length > sendMem.Length && sendMem.Length < consolidateMem.Length)) - { - var bufferIter = buffer; - var memIter = consolidateMem; - var trimmedSize = 0; - foreach (var command in _inFlightCommands) - { - if (bufferIter.Length == 0 || memIter.Length == 0) - { - break; - } - - int write; - if (trimmedSize == 0) - { - // first command, only write pending data - write = pending; - } - else if (command.Trim == 0) - { - write = command.Size; - } - else - { - if (bufferIter.Length < command.Trim + 1) - { - // not enough bytes to start writing the next command - break; - } - - bufferIter = bufferIter.Slice(command.Trim); - write = command.Size - command.Trim; - if (write == 0) - { - // the entire command was trimmed (canceled) - continue; - } - } - - write = Math.Min(memIter.Length, write); - write = Math.Min((int)bufferIter.Length, write); - bufferIter.Slice(0, write).CopyTo(memIter.Span); - memIter = memIter[write..]; - bufferIter = bufferIter.Slice(write); - trimmedSize += write; - } - - sendMem = consolidateMem[..trimmedSize]; - } - - // perform send - _stopwatch.Restart(); - sent = await _socketConnection.SendAsync(sendMem).ConfigureAwait(false); - _stopwatch.Stop(); - Interlocked.Add(ref _counter.SentBytes, sent); - if (isEnabledTraceLogging) - { - logger.LogTrace("Socket.SendAsync. Size: {0} Elapsed: {1}ms", sent, _stopwatch.Elapsed.TotalMilliseconds); - } - } - } - finally - { - // _pipeReader.AdvanceTo must be called exactly once for every - // _pipeReader.ReadAsync, which is why it is in the finally block - _pipeReader.AdvanceTo(consumedPos, examinedPos); - } - - if (result.IsCompleted) - { - // if the pipe has been completed, break - break; - } - } - } - catch (OperationCanceledException) - { - // ignore, intentionally disposed - } - catch (SocketClosedException) - { - // ignore, will be handled in read loop - } - catch (Exception ex) - { - logger.LogError(ex, "Unexpected error occured in write loop"); - throw; - } - - logger.LogDebug("WriteLoop finished."); - } -} +// using System.Buffers; +// using System.Diagnostics; +// using System.IO.Pipelines; +// using System.Threading.Channels; +// using Microsoft.Extensions.Logging; +// using NATS.Client.Core.Commands; +// +// namespace NATS.Client.Core.Internal; +// +// internal sealed class NatsPipeliningWriteProtocolProcessor : IAsyncDisposable +// { +// private readonly CancellationTokenSource _cts; +// private readonly ConnectionStatsCounter _counter; +// private readonly NatsOpts _opts; +// private readonly PipeReader _pipeReader; +// private readonly Queue _inFlightCommands; +// private readonly ChannelReader _queuedCommandReader; +// private readonly ISocketConnection _socketConnection; +// private readonly Stopwatch _stopwatch = new Stopwatch(); +// private int _disposed; +// +// public NatsPipeliningWriteProtocolProcessor(ISocketConnection socketConnection, CommandWriter commandWriter, NatsOpts opts, ConnectionStatsCounter counter) +// { +// _cts = new CancellationTokenSource(); +// _counter = counter; +// _inFlightCommands = commandWriter.InFlightCommands; +// _opts = opts; +// _pipeReader = commandWriter.PipeReader; +// _queuedCommandReader = commandWriter.QueuedCommandsReader; +// _socketConnection = socketConnection; +// WriteLoop = Task.Run(WriteLoopAsync); +// } +// +// public Task WriteLoop { get; } +// +// public async ValueTask DisposeAsync() +// { +// if (Interlocked.Increment(ref _disposed) == 1) +// { +// #if NET6_0 +// _cts.Cancel(); +// #else +// await _cts.CancelAsync().ConfigureAwait(false); +// #endif +// try +// { +// await WriteLoop.ConfigureAwait(false); // wait to drain writer +// } +// catch +// { +// // ignore +// } +// } +// } +// +// private async Task WriteLoopAsync() +// { +// var logger = _opts.LoggerFactory.CreateLogger(); +// var isEnabledTraceLogging = logger.IsEnabled(LogLevel.Trace); +// var cancellationToken = _cts.Token; +// var pending = 0; +// var trimming = 0; +// var examinedOffset = 0; +// +// // memory segment used to consolidate multiple small memory chunks +// // should <= (minimumSegmentSize * 0.5) in CommandWriter +// // 8520 should fit into 6 packets on 1500 MTU TLS connection or 1 packet on 9000 MTU TLS connection +// // assuming 40 bytes TCP overhead + 40 bytes TLS overhead per packet +// var consolidateMem = new Memory(new byte[8520]); +// +// // add up in flight command sum +// var inFlightSum = 0; +// foreach (var command in _inFlightCommands) +// { +// inFlightSum += command.Size; +// } +// +// try +// { +// while (true) +// { +// var result = await _pipeReader.ReadAsync(cancellationToken).ConfigureAwait(false); +// if (result.IsCanceled) +// { +// break; +// } +// +// var consumedPos = result.Buffer.Start; +// var examinedPos = result.Buffer.Start; +// try +// { +// var buffer = result.Buffer.Slice(examinedOffset); +// while (inFlightSum < result.Buffer.Length) +// { +// QueuedCommand queuedCommand; +// while (!_queuedCommandReader.TryRead(out queuedCommand)) +// { +// await _queuedCommandReader.WaitToReadAsync(cancellationToken).ConfigureAwait(false); +// } +// +// _inFlightCommands.Enqueue(queuedCommand); +// inFlightSum += queuedCommand.Size; +// } +// +// while (buffer.Length > 0) +// { +// if (pending == 0) +// { +// var peek = _inFlightCommands.Peek(); +// pending = peek.Size; +// trimming = peek.Trim; +// } +// +// if (trimming > 0) +// { +// var trimmed = Math.Min(trimming, (int)buffer.Length); +// consumedPos = buffer.GetPosition(trimmed); +// examinedPos = buffer.GetPosition(trimmed); +// examinedOffset = 0; +// buffer = buffer.Slice(trimmed); +// pending -= trimmed; +// trimming -= trimmed; +// if (pending == 0) +// { +// // the entire command was trimmed (canceled) +// inFlightSum -= _inFlightCommands.Dequeue().Size; +// } +// +// continue; +// } +// +// var sendMem = buffer.First; +// var maxSize = 0; +// var maxSizeCap = Math.Max(sendMem.Length, consolidateMem.Length); +// var doTrim = false; +// foreach (var command in _inFlightCommands) +// { +// if (maxSize == 0) +// { +// // first command; set to pending +// maxSize = pending; +// continue; +// } +// +// if (maxSize > maxSizeCap) +// { +// // over cap +// break; +// } +// +// if (command.Trim > 0) +// { +// // will have to trim +// doTrim = true; +// break; +// } +// +// maxSize += command.Size; +// } +// +// if (sendMem.Length > maxSize) +// { +// sendMem = sendMem[..maxSize]; +// } +// +// var bufferIter = buffer; +// if (doTrim || (bufferIter.Length > sendMem.Length && sendMem.Length < consolidateMem.Length)) +// { +// var memIter = consolidateMem; +// var trimmedSize = 0; +// foreach (var command in _inFlightCommands) +// { +// if (bufferIter.Length == 0 || memIter.Length == 0) +// { +// break; +// } +// +// int write; +// if (trimmedSize == 0) +// { +// // first command, only write pending data +// write = pending; +// } +// else if (command.Trim == 0) +// { +// write = command.Size; +// } +// else +// { +// if (bufferIter.Length < command.Trim + 1) +// { +// // not enough bytes to start writing the next command +// break; +// } +// +// bufferIter = bufferIter.Slice(command.Trim); +// write = command.Size - command.Trim; +// if (write == 0) +// { +// // the entire command was trimmed (canceled) +// continue; +// } +// } +// +// write = Math.Min(memIter.Length, write); +// write = Math.Min((int)bufferIter.Length, write); +// bufferIter.Slice(0, write).CopyTo(memIter.Span); +// memIter = memIter[write..]; +// bufferIter = bufferIter.Slice(write); +// trimmedSize += write; +// } +// +// sendMem = consolidateMem[..trimmedSize]; +// } +// +// // perform send +// _stopwatch.Restart(); +// var sent = await _socketConnection.SendAsync(sendMem).ConfigureAwait(false); +// _stopwatch.Stop(); +// Interlocked.Add(ref _counter.SentBytes, sent); +// if (isEnabledTraceLogging) +// { +// logger.LogTrace("Socket.SendAsync. Size: {0} Elapsed: {1}ms", sent, _stopwatch.Elapsed.TotalMilliseconds); +// } +// +// var consumed = 0; +// var sentAndTrimmed = sent; +// while (consumed < sentAndTrimmed) +// { +// if (pending == 0) +// { +// var peek = _inFlightCommands.Peek(); +// pending = peek.Size - peek.Trim; +// consumed += peek.Trim; +// sentAndTrimmed += peek.Trim; +// +// if (pending == 0) +// { +// // the entire command was trimmed (canceled) +// inFlightSum -= _inFlightCommands.Dequeue().Size; +// continue; +// } +// } +// +// if (pending <= sentAndTrimmed - consumed) +// { +// // the entire command was sent +// inFlightSum -= _inFlightCommands.Dequeue().Size; +// Interlocked.Add(ref _counter.PendingMessages, -1); +// Interlocked.Add(ref _counter.SentMessages, 1); +// +// // mark the bytes as consumed, and reset pending +// consumed += pending; +// pending = 0; +// } +// else +// { +// // the entire command was not sent; decrement pending by +// // the number of bytes from the command that was sent +// pending += consumed - sentAndTrimmed; +// break; +// } +// } +// +// if (consumed > 0) +// { +// // mark fully sent commands as consumed +// consumedPos = buffer.GetPosition(consumed); +// examinedOffset = sentAndTrimmed - consumed; +// } +// else +// { +// // no commands were consumed +// examinedOffset += sentAndTrimmed; +// } +// +// // lop off sent bytes for next iteration +// examinedPos = buffer.GetPosition(sentAndTrimmed); +// buffer = buffer.Slice(sentAndTrimmed); +// } +// } +// finally +// { +// _pipeReader.AdvanceTo(consumedPos, examinedPos); +// } +// +// if (result.IsCompleted) +// { +// break; +// } +// } +// } +// catch (OperationCanceledException) +// { +// // ignore, intentionally disposed +// } +// catch (SocketClosedException) +// { +// // ignore, will be handled in read loop +// } +// catch (Exception ex) +// { +// logger.LogError(ex, "Unexpected error occured in write loop"); +// throw; +// } +// +// logger.LogDebug("WriteLoop finished."); +// } +// } diff --git a/src/NATS.Client.Core/NatsConnection.cs b/src/NATS.Client.Core/NatsConnection.cs index 6a6236f2..0a285d28 100644 --- a/src/NATS.Client.Core/NatsConnection.cs +++ b/src/NATS.Client.Core/NatsConnection.cs @@ -56,7 +56,6 @@ public partial class NatsConnection : INatsConnection private volatile NatsUri? _currentConnectUri; private volatile NatsUri? _lastSeedConnectUri; private NatsReadProtocolProcessor? _socketReader; - private NatsPipeliningWriteProtocolProcessor? _socketWriter; private TaskCompletionSource _waitForOpenConnection; private TlsCerts? _tlsCerts; private UserCredentials? _userCredentials; @@ -80,7 +79,7 @@ public NatsConnection(NatsOpts opts) _cancellationTimerPool = new CancellationTimerPool(_pool, _disposedCancellationTokenSource.Token); _name = opts.Name; Counter = new ConnectionStatsCounter(); - CommandWriter = new CommandWriter(Opts, Counter, EnqueuePing); + CommandWriter = new CommandWriter(_pool, Opts, Counter, EnqueuePing); InboxPrefix = NewInbox(opts.InboxPrefix); SubscriptionManager = new SubscriptionManager(this, InboxPrefix); _logger = opts.LoggerFactory.CreateLogger(); @@ -218,7 +217,8 @@ internal ValueTask UnsubscribeAsync(int sid) { try { - return CommandWriter.UnsubscribeAsync(sid, CancellationToken.None); + // TODO: use maxMsgs in INatsSub to unsubscribe. + return CommandWriter.UnsubscribeAsync(sid, null, CancellationToken.None); } catch (Exception ex) { @@ -431,7 +431,7 @@ private async ValueTask SetupReaderWriterAsync(bool reconnect) // Authentication _userCredentials?.Authenticate(_clientOpts, WritableServerInfo); - await using (var priorityCommandWriter = new PriorityCommandWriter(_socket!, Opts, Counter, EnqueuePing)) + await using (var priorityCommandWriter = new PriorityCommandWriter(_pool, _socket!, Opts, Counter, EnqueuePing)) { // add CONNECT and PING command to priority lane await priorityCommandWriter.CommandWriter.ConnectAsync(_clientOpts, CancellationToken.None).ConfigureAwait(false); @@ -455,7 +455,7 @@ await using (var priorityCommandWriter = new PriorityCommandWriter(_socket!, Opt } // create the socket writer - _socketWriter = CommandWriter.CreateNatsPipeliningWriteProtocolProcessor(_socket!); + CommandWriter.Reset(_socket!); lock (_gate) { @@ -500,6 +500,8 @@ private async void ReconnectLoop() // If dispose this client, WaitForClosed throws OperationCanceledException so stop reconnect-loop correctly. await _socket!.WaitForClosed.ConfigureAwait(false); + await CommandWriter.CancelReaderLoopAsync().ConfigureAwait(false); + _logger.LogTrace(NatsLogEvents.Connection, "Connection {Name} is closed. Will cleanup and reconnect", _name); lock (_gate) { @@ -800,13 +802,6 @@ private async ValueTask DisposeSocketComponentAsync(IAsyncDisposable component, // Dispose Reader(Drain read buffers but no reads more) private async ValueTask DisposeSocketAsync(bool asyncReaderDispose) { - // writer's internal buffer/channel is not thread-safe, must wait until complete. - if (_socketWriter != null) - { - await DisposeSocketComponentAsync(_socketWriter, "socket writer").ConfigureAwait(false); - _socketWriter = null; - } - if (_socket != null) { await DisposeSocketComponentAsync(_socket, "socket").ConfigureAwait(false); diff --git a/src/NATS.Client.Core/NatsLogEvents.cs b/src/NATS.Client.Core/NatsLogEvents.cs index fc1cd0fe..15456e79 100644 --- a/src/NATS.Client.Core/NatsLogEvents.cs +++ b/src/NATS.Client.Core/NatsLogEvents.cs @@ -11,4 +11,5 @@ public static class NatsLogEvents public static readonly EventId Protocol = new(1005, nameof(Protocol)); public static readonly EventId TcpSocket = new(1006, nameof(TcpSocket)); public static readonly EventId Internal = new(1006, nameof(Internal)); + public static readonly EventId Buffer = new(1007, nameof(Buffer)); } diff --git a/src/NATS.Client.Core/NatsOpts.cs b/src/NATS.Client.Core/NatsOpts.cs index 4056427e..dc722240 100644 --- a/src/NATS.Client.Core/NatsOpts.cs +++ b/src/NATS.Client.Core/NatsOpts.cs @@ -31,7 +31,9 @@ public sealed record NatsOpts public ILoggerFactory LoggerFactory { get; init; } = NullLoggerFactory.Instance; - public int WriterBufferSize { get; init; } = 1048576; + // Same as default pipelines pause writer size + // Performing better compared to nats bench on localhost + public int WriterBufferSize { get; init; } = 65536; public int ReaderBufferSize { get; init; } = 1048576; diff --git a/tests/NATS.Client.Core.Tests/CancellationTest.cs b/tests/NATS.Client.Core.Tests/CancellationTest.cs index 1148cd56..732c07f8 100644 --- a/tests/NATS.Client.Core.Tests/CancellationTest.cs +++ b/tests/NATS.Client.Core.Tests/CancellationTest.cs @@ -15,15 +15,18 @@ public async Task CommandTimeoutTest() await using var conn = server.CreateClientConnection(NatsOpts.Default with { CommandTimeout = TimeSpan.FromMilliseconds(1) }); await conn.ConnectAsync(); + var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + var cancellationToken = cts.Token; + // stall the flush task - await conn.CommandWriter.TestStallFlushAsync(TimeSpan.FromSeconds(5)); + Assert.True(conn.CommandWriter.TestStallFlush()); // commands that call ConnectAsync throw OperationCanceledException - await Assert.ThrowsAsync(() => conn.PingAsync().AsTask()); - await Assert.ThrowsAsync(() => conn.PublishAsync("test").AsTask()); - await Assert.ThrowsAsync(async () => + await Assert.ThrowsAsync(() => conn.PingAsync(cancellationToken).AsTask()); + await Assert.ThrowsAsync(() => conn.PublishAsync("test", cancellationToken: cancellationToken).AsTask()); + await Assert.ThrowsAsync(async () => { - await foreach (var unused in conn.SubscribeAsync("test")) + await foreach (var unused in conn.SubscribeAsync("test", cancellationToken: cancellationToken)) { } }); @@ -75,4 +78,22 @@ await foreach (var unused in conn.SubscribeAsync("test", cancellationTok } }); } + + [Fact] + public async Task Cancellation_timer() + { + var objectPool = new ObjectPool(10); + var cancellationTimerPool = new CancellationTimerPool(objectPool, CancellationToken.None); + var cancellationTimer = cancellationTimerPool.Start(TimeSpan.FromSeconds(2), CancellationToken.None); + + try + { + await Task.Delay(TimeSpan.FromSeconds(4), cancellationTimer.Token); + _output.WriteLine($"delayed 4 seconds"); + } + catch (Exception e) + { + _output.WriteLine($"Exception: {e.GetType().Name}"); + } + } } diff --git a/tests/NATS.Client.Core.Tests/ConnectionRetryTest.cs b/tests/NATS.Client.Core.Tests/ConnectionRetryTest.cs index 261337c6..132b727c 100644 --- a/tests/NATS.Client.Core.Tests/ConnectionRetryTest.cs +++ b/tests/NATS.Client.Core.Tests/ConnectionRetryTest.cs @@ -58,6 +58,7 @@ public async Task Retry_and_connect_after_disconnected() [Fact] public async Task Reconnect_doesnt_drop_partially_sent_msgs() { + const int msgSize = 1048576; // 1MiB await using var server = NatsServer.Start(); await using var pubConn = server.CreateClientConnection(); @@ -86,6 +87,7 @@ await foreach (var msg in sub.Msgs.ReadAllAsync(timeoutCts.Token)) } else { + Assert.Equal(msgSize, msg.Data.Length); Interlocked.Increment(ref received); } } @@ -99,7 +101,7 @@ await foreach (var msg in sub.Msgs.ReadAllAsync(timeoutCts.Token)) } var sent = 0; - var data = new byte[1048576]; // 1MiB + var data = new byte[msgSize]; var sendTask = Task.Run(async () => { while (!stopCts.IsCancellationRequested) @@ -131,6 +133,8 @@ await foreach (var msg in sub.Msgs.ReadAllAsync(timeoutCts.Token)) Assert.True(reconnects > 0, "connection did not reconnect"); Assert.True(received <= sent, $"duplicate messages sent on wire- {sent} sent, {received} received"); + _output.WriteLine($"reconnects: {reconnects}, sent: {sent}, received: {received}"); + // some messages may still be lost, as socket could have been disconnected // after socket.WriteAsync returned, but before OS sent // check to ensure that the loss was < 1% diff --git a/tests/NATS.Client.Core.Tests/NatsHeaderTest.cs b/tests/NATS.Client.Core.Tests/NatsHeaderTest.cs index 6b74776a..e2081e74 100644 --- a/tests/NATS.Client.Core.Tests/NatsHeaderTest.cs +++ b/tests/NATS.Client.Core.Tests/NatsHeaderTest.cs @@ -21,8 +21,8 @@ public async Task WriterTests() ["key"] = "a-long-header-value", }; var pipe = new Pipe(new PipeOptions(pauseWriterThreshold: 0)); - var writer = new HeaderWriter(pipe.Writer, Encoding.UTF8); - var written = writer.Write(headers); + var writer = new HeaderWriter(Encoding.UTF8); + var written = writer.Write(pipe.Writer, headers); var text = "NATS/1.0\r\nk1: v1\r\nk2: v2-0\r\nk2: v2-1\r\na-long-header-key: value\r\nkey: a-long-header-value\r\n\r\n"; var expected = new ReadOnlySequence(Encoding.UTF8.GetBytes(text)); @@ -30,7 +30,24 @@ public async Task WriterTests() Assert.Equal(expected.Length, written); await pipe.Writer.FlushAsync(); var result = await pipe.Reader.ReadAtLeastAsync((int)written); + Assert.True(expected.ToSpan().SequenceEqual(result.Buffer.ToSpan())); + _output.WriteLine($"Buffer:\n{result.Buffer.FirstSpan.Dump()}"); + } + [Fact] + public async Task WriterEmptyTests() + { + var headers = new NatsHeaders(); + var pipe = new Pipe(new PipeOptions(pauseWriterThreshold: 0)); + var writer = new HeaderWriter(Encoding.UTF8); + var written = writer.Write(pipe.Writer, headers); + + var text = "NATS/1.0\r\n\r\n"; + var expected = new ReadOnlySequence(Encoding.UTF8.GetBytes(text)); + + Assert.Equal(expected.Length, written); + await pipe.Writer.FlushAsync(); + var result = await pipe.Reader.ReadAtLeastAsync((int)written); Assert.True(expected.ToSpan().SequenceEqual(result.Buffer.ToSpan())); _output.WriteLine($"Buffer:\n{result.Buffer.FirstSpan.Dump()}"); } diff --git a/tests/NATS.Client.Core.Tests/ProtocolTest.cs b/tests/NATS.Client.Core.Tests/ProtocolTest.cs index 35a64e26..6c7228c3 100644 --- a/tests/NATS.Client.Core.Tests/ProtocolTest.cs +++ b/tests/NATS.Client.Core.Tests/ProtocolTest.cs @@ -1,4 +1,5 @@ using System.Buffers; +using System.Collections.Concurrent; using System.Text; using Microsoft.Extensions.Logging; using NATS.Client.TestUtilities; @@ -335,17 +336,29 @@ public async Task Reconnect_with_sub_and_additional_commands() () => proxy.ClientFrames.Any(f => f.Message.StartsWith("PUB foo"))); var frames = proxy.ClientFrames.Select(f => f.Message).ToList(); + + foreach (var frame in frames) + { + _output.WriteLine($"frame: {frame}"); + } + Assert.StartsWith("SUB foo", frames[0]); - Assert.StartsWith("PUB bar1", frames[1]); - Assert.StartsWith("PUB bar2", frames[2]); - Assert.StartsWith("PUB bar3", frames[3]); - Assert.StartsWith("PUB foo", frames[4]); + + for (var i = 0; i < 100; i++) + { + Assert.StartsWith($"PUB bar{i}", frames[i + 1]); + } + + Assert.StartsWith("PUB foo", frames[101]); await nats.DisposeAsync(); } - [Fact] - public async Task Protocol_parser_under_load() + [Theory] + [InlineData(1)] + [InlineData(1024)] + [InlineData(1024 * 1024)] + public async Task Protocol_parser_under_load(int size) { await using var server = NatsServer.Start(); var logger = new InMemoryTestLoggerFactory(LogLevel.Error); @@ -355,24 +368,27 @@ public async Task Protocol_parser_under_load() using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); var signal = new WaitSignal(); - + var counts = new ConcurrentDictionary(); _ = Task.Run( async () => { var count = 0; - await foreach (var unused in nats.SubscribeAsync("x", cancellationToken: cts.Token)) + await foreach (var msg in nats.SubscribeAsync("x.*", cancellationToken: cts.Token)) { - if (++count > 10_000) + if (++count > 100) signal.Pulse(); + counts.AddOrUpdate(msg.Subject, 1, (_, c) => c + 1); } }, cts.Token); + var payload = new byte[size]; + var r = 0; _ = Task.Run( async () => { while (!cts.Token.IsCancellationRequested) - await nats.PublishAsync("x", "x", cancellationToken: cts.Token); + await nats.PublishAsync($"x.{Volatile.Read(ref r)}", payload, cancellationToken: cts.Token); }, cts.Token); @@ -382,12 +398,21 @@ await foreach (var unused in nats.SubscribeAsync("x", cancellationToken: { await Task.Delay(1_000, cts.Token); await server.RestartAsync(); + Interlocked.Increment(ref r); + await Task.Delay(1_000, cts.Token); } foreach (var log in logger.Logs.Where(x => x.EventId == NatsLogEvents.Protocol && x.LogLevel == LogLevel.Error)) { Assert.DoesNotContain("Unknown Protocol Operation", log.Message); } + + foreach (var (key, value) in counts) + { + _output.WriteLine($"{key} {value}"); + } + + counts.Count.Should().BeGreaterOrEqualTo(3); } private sealed class NatsSubReconnectTest : NatsSubBase @@ -403,9 +428,10 @@ internal override async ValueTask WriteReconnectCommandsAsync(CommandWriter comm await base.WriteReconnectCommandsAsync(commandWriter, sid); // Any additional commands to send on reconnect - await commandWriter.PublishAsync("bar1", default, default, default, NatsRawSerializer.Default, default); - await commandWriter.PublishAsync("bar2", default, default, default, NatsRawSerializer.Default, default); - await commandWriter.PublishAsync("bar3", default, default, default, NatsRawSerializer.Default, default); + for (var i = 0; i < 100; i++) + { + await commandWriter.PublishAsync($"bar{i}", default, default, default, NatsRawSerializer.Default, default); + } } protected override ValueTask ReceiveInternalAsync(string subject, string? replyTo, ReadOnlySequence? headersBuffer, ReadOnlySequence payloadBuffer) diff --git a/tests/NATS.Client.Core.Tests/SendBufferTest.cs b/tests/NATS.Client.Core.Tests/SendBufferTest.cs new file mode 100644 index 00000000..d6d09880 --- /dev/null +++ b/tests/NATS.Client.Core.Tests/SendBufferTest.cs @@ -0,0 +1,78 @@ +using System.Diagnostics; +using NATS.Client.TestUtilities; + +namespace NATS.Client.Core.Tests; + +public class SendBufferTest +{ + private readonly ITestOutputHelper _output; + + public SendBufferTest(ITestOutputHelper output) => _output = output; + + [Fact] + public async Task Send_cancel() + { + void Log(string m) => TmpFileLogger.Log(m); + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + + await using var server = new MockServer( + async (s, cmd) => + { + if (cmd.Name == "PUB" && cmd.Subject == "pause") + { + s.Log("[S] pause"); + await Task.Delay(10_000, cts.Token); + } + }, + Log, + cts.Token); + + Log("__________________________________"); + + await using var nats = new NatsConnection(new NatsOpts { Url = server.Url }); + + Log($"[C] connect {server.Url}"); + await nats.ConnectAsync(); + + Log($"[C] ping"); + var rtt = await nats.PingAsync(cts.Token); + Log($"[C] ping rtt={rtt}"); + + server.Log($"[C] publishing pause..."); + await nats.PublishAsync("pause", "x", cancellationToken: cts.Token); + + server.Log($"[C] publishing 1M..."); + var payload = new byte[1024 * 1024]; + var tasks = new List(); + for (var i = 0; i < 10; i++) + { + var i1 = i; + tasks.Add(Task.Run(async () => + { + var stopwatch = Stopwatch.StartNew(); + + try + { + Log($"[C] ({i1}) publish..."); + await nats.PublishAsync("x", payload, cancellationToken: cts.Token); + } + catch (Exception e) + { + stopwatch.Stop(); + Log($"[C] ({i1}) publish cancelled after {stopwatch.Elapsed.TotalSeconds:n0} s (exception: {e.GetType()})"); + return; + } + + stopwatch.Stop(); + Log($"[C] ({i1}) publish took {stopwatch.Elapsed.TotalSeconds:n3} s"); + })); + } + + for (var i = 0; i < 10; i++) + { + Log($"[C] await tasks {i}..."); + await tasks[i]; + } + } +} diff --git a/tests/NATS.Client.TestUtilities/MockServer.cs b/tests/NATS.Client.TestUtilities/MockServer.cs new file mode 100644 index 00000000..d9a0f273 --- /dev/null +++ b/tests/NATS.Client.TestUtilities/MockServer.cs @@ -0,0 +1,112 @@ +using System.Net; +using System.Net.Sockets; +using System.Text; +using System.Text.RegularExpressions; + +namespace NATS.Client.TestUtilities; + +public class MockServer : IAsyncDisposable +{ + private readonly Action _logger; + private readonly TcpListener _server; + private readonly Task _accept; + + public MockServer( + Func handler, + Action logger, + CancellationToken cancellationToken) + { + _logger = logger; + _server = new TcpListener(IPAddress.Parse("127.0.0.1"), 0); + _server.Start(); + Port = ((IPEndPoint)_server.LocalEndpoint).Port; + + _accept = Task.Run( + async () => + { + var client = await _server.AcceptTcpClientAsync(); + + var stream = client.GetStream(); + + var sw = new StreamWriter(stream, Encoding.ASCII); + await sw.WriteAsync("INFO {}\r\n"); + await sw.FlushAsync(); + + var sr = new StreamReader(stream, Encoding.ASCII); + + while (!cancellationToken.IsCancellationRequested) + { + Log($"[S] >>> READ LINE"); + var line = (await sr.ReadLineAsync())!; + + if (line.StartsWith("CONNECT")) + { + Log($"[S] RCV CONNECT"); + } + else if (line.StartsWith("PING")) + { + Log($"[S] RCV PING"); + await sw.WriteAsync("PONG\r\n"); + await sw.FlushAsync(); + Log($"[S] SND PONG"); + } + else if (line.StartsWith("SUB")) + { + var m = Regex.Match(line, @"^SUB\s+(?\S+)"); + var subject = m.Groups["subject"].Value; + Log($"[S] RCV SUB {subject}"); + await handler(this, new Cmd("SUB", subject, 0)); + } + else if (line.StartsWith("PUB") || line.StartsWith("HPUB")) + { + var m = Regex.Match(line, @"^(H?PUB)\s+(?\S+).*?(?\d+)$"); + var size = int.Parse(m.Groups["size"].Value); + var subject = m.Groups["subject"].Value; + Log($"[S] RCV PUB {subject} {size}"); + var read = 0; + var buffer = new byte[size]; + while (read < size) + { + var received = await stream.ReadAsync(buffer, read, size - read); + read += received; + Log($"[S] RCV {received} bytes (size={size} read={read})"); + } + + await handler(this, new Cmd("PUB", subject, size)); + await sr.ReadLineAsync(); + } + else + { + Log($"[S] RCV LINE: {line}"); + } + } + }, + cancellationToken); + } + + public int Port { get; } + + public string Url => $"127.0.0.1:{Port}"; + + public async ValueTask DisposeAsync() + { + _server.Stop(); + try + { + await _accept; + } + catch (OperationCanceledException) + { + } + catch (SocketException) + { + } + catch (IOException) + { + } + } + + public void Log(string m) => _logger(m); + + public record Cmd(string Name, string Subject, int Size); +} diff --git a/tests/NATS.Client.TestUtilities/TmpFileLogger.cs b/tests/NATS.Client.TestUtilities/TmpFileLogger.cs new file mode 100644 index 00000000..72d8e8ee --- /dev/null +++ b/tests/NATS.Client.TestUtilities/TmpFileLogger.cs @@ -0,0 +1,18 @@ +using System.Text; + +namespace NATS.Client.TestUtilities; + +public static class TmpFileLogger +{ + private static readonly object Gate = new(); + + public static void Log(string m) + { + lock (Gate) + { + using var fs = new FileStream("/tmp/test.log", FileMode.Append, FileAccess.Write, FileShare.ReadWrite); + using var sw = new StreamWriter(fs, Encoding.UTF8); + sw.WriteLine($"{DateTime.Now:HH:mm:ss.fff} {m}"); + } + } +}