Skip to content

Commit

Permalink
Fix memory leaks with AsyncDictionaryHelper
Browse files Browse the repository at this point in the history
To ensure that entries from _pendingRequests are removed when removing a key from
_dictionary.
  • Loading branch information
jnyrup committed Aug 17, 2023
1 parent 6d712e9 commit f62ec11
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 37 deletions.
2 changes: 1 addition & 1 deletion lib/PuppeteerSharp/Browser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ public bool IsClosed

/// <inheritdoc/>
public ITarget[] Targets()
=> TargetManager.GetAvailableTargets().InnerDictionary.Values.ToArray();
=> TargetManager.GetAvailableTargets().Values.ToArray();

/// <inheritdoc/>
public async Task<IBrowserContext> CreateIncognitoBrowserContextAsync()
Expand Down
16 changes: 7 additions & 9 deletions lib/PuppeteerSharp/ChromeTargetManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ internal class ChromeTargetManager : ITargetManager
private readonly Func<TargetInfo, CDPSession, Target> _targetFactoryFunc;
private readonly Func<TargetInfo, bool> _targetFilterFunc;
private readonly ILogger<ChromeTargetManager> _logger;
private readonly ConcurrentDictionary<string, Target> _availableTargetsByTargetIdDictionary = new();
private readonly AsyncDictionaryHelper<string, Target> _attachedTargetsByTargetId;
private readonly AsyncDictionaryHelper<string, Target> _attachedTargetsByTargetId = new("Target {0} not found");
private readonly ConcurrentDictionary<string, Target> _attachedTargetsBySessionId = new();
private readonly ConcurrentDictionary<string, TargetInfo> _discoveredTargetsByTargetId = new();
private readonly ConcurrentDictionary<ICDPConnection, List<TargetInterceptor>> _targetInterceptors = new();
Expand All @@ -37,7 +36,6 @@ internal class ChromeTargetManager : ITargetManager
Func<TargetInfo, bool> targetFilterFunc,
int targetDiscoveryTimeout = 0)
{
_attachedTargetsByTargetId = new AsyncDictionaryHelper<string, Target>(_availableTargetsByTargetIdDictionary, "Target {0} not found");
_connection = connection;
_targetFilterFunc = targetFilterFunc;
_targetFactoryFunc = targetFactoryFunc;
Expand Down Expand Up @@ -188,7 +186,7 @@ private void OnTargetCreated(TargetCreatedResponse e)

if (e.TargetInfo.Type == TargetType.Browser && e.TargetInfo.Attached)
{
if (_availableTargetsByTargetIdDictionary.ContainsKey(e.TargetInfo.TargetId))
if (_attachedTargetsByTargetId.ContainsKey(e.TargetInfo.TargetId))
{
return;
}
Expand All @@ -204,7 +202,7 @@ private async void OnTargetDestroyed(TargetDestroyedResponse e)
await EnsureTargetsIdsForInit().ConfigureAwait(false);
FinishInitializationIfReady(e.TargetId);

if (targetInfo?.Type == TargetType.ServiceWorker && _availableTargetsByTargetIdDictionary.TryRemove(e.TargetId, out var target))
if (targetInfo?.Type == TargetType.ServiceWorker && _attachedTargetsByTargetId.TryRemove(e.TargetId, out var target))
{
TargetGone?.Invoke(this, new TargetChangedArgs { Target = target, TargetInfo = targetInfo });
}
Expand All @@ -215,7 +213,7 @@ private void OnTargetInfoChanged(TargetCreatedResponse e)
_discoveredTargetsByTargetId[e.TargetInfo.TargetId] = e.TargetInfo;

if (_ignoredTargets.Contains(e.TargetInfo.TargetId) ||
!_availableTargetsByTargetIdDictionary.TryGetValue(e.TargetInfo.TargetId, out var target) ||
!_attachedTargetsByTargetId.TryGetValue(e.TargetInfo.TargetId, out var target) ||
!e.TargetInfo.Attached)
{
return;
Expand Down Expand Up @@ -259,7 +257,7 @@ async Task SilentDetach()
await EnsureTargetsIdsForInit().ConfigureAwait(false);
FinishInitializationIfReady(targetInfo.TargetId);
await SilentDetach().ConfigureAwait(false);
if (_availableTargetsByTargetIdDictionary.ContainsKey(targetInfo.TargetId))
if (_attachedTargetsByTargetId.ContainsKey(targetInfo.TargetId))
{
return;
}
Expand All @@ -279,7 +277,7 @@ async Task SilentDetach()
return;
}

var existingTarget = _availableTargetsByTargetIdDictionary.TryGetValue(targetInfo.TargetId, out var target);
var existingTarget = _attachedTargetsByTargetId.TryGetValue(targetInfo.TargetId, out var target);
if (!existingTarget)
{
target = _targetFactoryFunc(targetInfo, session);
Expand Down Expand Up @@ -377,7 +375,7 @@ private void OnDetachedFromTarget(object sender, TargetDetachedFromTargetRespons
return;
}

_availableTargetsByTargetIdDictionary.TryRemove(target.TargetId, out _);
_attachedTargetsByTargetId.TryRemove(target.TargetId, out _);
TargetGone?.Invoke(this, new TargetChangedArgs { Target = target });
}
}
Expand Down
8 changes: 3 additions & 5 deletions lib/PuppeteerSharp/Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ public class Connection : IDisposable, ICDPConnection
private readonly TaskQueue _callbackQueue = new();

private readonly ConcurrentDictionary<int, MessageTask> _callbacks = new();
private readonly ConcurrentDictionary<string, CDPSession> _sessions = new();
private readonly AsyncDictionaryHelper<string, CDPSession> _asyncSessions;
private readonly AsyncDictionaryHelper<string, CDPSession> _sessions = new("Session {0} not found");
private readonly List<string> _manuallyAttached = new();
private int _lastId;

Expand All @@ -38,7 +37,6 @@ internal Connection(string url, int delay, bool enqueueAsyncMessages, IConnectio
_logger = LoggerFactory.CreateLogger<Connection>();

MessageQueue = new AsyncMessageQueue(enqueueAsyncMessages, _logger);
_asyncSessions = new AsyncDictionaryHelper<string, CDPSession>(_sessions, "Session {0} not found");

Transport.MessageReceived += Transport_MessageReceived;
Transport.Closed += Transport_Closed;
Expand Down Expand Up @@ -221,7 +219,7 @@ internal void Close(string closeReason)

internal CDPSession GetSession(string sessionId) => _sessions.GetValueOrDefault(sessionId);

internal Task<CDPSession> GetSessionAsync(string sessionId) => _asyncSessions.GetItemAsync(sessionId);
internal Task<CDPSession> GetSessionAsync(string sessionId) => _sessions.GetItemAsync(sessionId);

/// <summary>
/// Releases all resource used by the <see cref="Connection"/> object.
Expand Down Expand Up @@ -287,7 +285,7 @@ private void ProcessIncomingMessage(ConnectionResponse obj)
{
var sessionId = param.SessionId;
var session = new CDPSession(this, param.TargetInfo.Type, sessionId);
_asyncSessions.AddItem(sessionId, session);
_sessions.AddItem(sessionId, session);

SessionAttached?.Invoke(this, new SessionEventArgs { Session = session });

Expand Down
9 changes: 3 additions & 6 deletions lib/PuppeteerSharp/FirefoxTargetManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ internal class FirefoxTargetManager : ITargetManager
private readonly Func<TargetInfo, bool> _targetFilterFunc;
private readonly ILogger<FirefoxTargetManager> _logger;
private readonly ConcurrentDictionary<ICDPConnection, List<TargetInterceptor>> _targetInterceptors = new();

private readonly ConcurrentDictionary<string, Target> _availableTargetsByTargetIdDictionary = new();
private readonly AsyncDictionaryHelper<string, Target> _availableTargetsByTargetId;
private readonly AsyncDictionaryHelper<string, Target> _availableTargetsByTargetId = new("Target {0} not found");
private readonly ConcurrentDictionary<string, Target> _availableTargetsBySessionId = new();
private readonly ConcurrentDictionary<string, TargetInfo> _discoveredTargetsByTargetId = new();
private readonly TaskCompletionSource<bool> _initializeCompletionSource = new();
Expand All @@ -31,7 +29,6 @@ internal class FirefoxTargetManager : ITargetManager
Func<TargetInfo, CDPSession, Target> targetFactoryFunc,
Func<TargetInfo, bool> targetFilterFunc)
{
_availableTargetsByTargetId = new AsyncDictionaryHelper<string, Target>(_availableTargetsByTargetIdDictionary, "Target {0} not found");
_connection = connection;
_targetFilterFunc = targetFilterFunc;
_targetFactoryFunc = targetFactoryFunc;
Expand Down Expand Up @@ -148,7 +145,7 @@ private void OnTargetDestroyed(TargetDestroyedResponse e)
_discoveredTargetsByTargetId.TryRemove(e.TargetId, out var targetInfo);
FinishInitializationIfReady(e.TargetId);

if (_availableTargetsByTargetIdDictionary.TryGetValue(e.TargetId, out var target))
if (_availableTargetsByTargetId.TryGetValue(e.TargetId, out var target))
{
TargetGone?.Invoke(this, new TargetChangedArgs { Target = target, TargetInfo = targetInfo });
}
Expand All @@ -159,7 +156,7 @@ private void OnAttachedToTarget(object sender, TargetAttachedToTargetResponse e)
var parent = sender as ICDPConnection;
var targetInfo = e.TargetInfo;
var session = _connection.GetSession(e.SessionId) ?? throw new PuppeteerException($"Session {e.SessionId} was not created.");
var existingTarget = _availableTargetsByTargetIdDictionary.TryGetValue(targetInfo.TargetId, out var target);
var existingTarget = _availableTargetsByTargetId.TryGetValue(targetInfo.TargetId, out var target);
session.MessageReceived += OnMessageReceived;

_availableTargetsBySessionId.TryAdd(session.Id, target);
Expand Down
16 changes: 5 additions & 11 deletions lib/PuppeteerSharp/FrameTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,18 @@ namespace PuppeteerSharp
{
internal class FrameTree
{
private readonly ConcurrentDictionary<string, Frame> _frames = new();
private readonly ConcurrentDictionary<string, string> _parentIds = new();
private readonly ConcurrentDictionary<string, List<string>> _childIds = new();
private readonly ConcurrentDictionary<string, List<TaskCompletionSource<Frame>>> _waitRequests = new();
private readonly AsyncDictionaryHelper<string, Frame> _asyncFrames;

public FrameTree()
{
_asyncFrames = new AsyncDictionaryHelper<string, Frame>(_frames, "Frame {0} not found");
}
private readonly AsyncDictionaryHelper<string, Frame> _frames = new("Frame {0} not found");

public Frame MainFrame { get; set; }

public Frame[] Frames => _frames.Values.ToArray();

internal Task<Frame> GetFrameAsync(string frameId) => _asyncFrames.GetItemAsync(frameId);
internal Task<Frame> GetFrameAsync(string frameId) => _frames.GetItemAsync(frameId);

internal Task<Frame> TryGetFrameAsync(string frameId) => _asyncFrames.TryGetItemAsync(frameId);
internal Task<Frame> TryGetFrameAsync(string frameId) => _frames.TryGetItemAsync(frameId);

internal Frame GetById(string id)
{
Expand All @@ -52,7 +46,7 @@ internal Task<Frame> WaitForFrameAsync(string frameId)

internal void AddFrame(Frame frame)
{
_asyncFrames.AddItem(frame.Id, frame);
_frames.AddItem(frame.Id, frame);
if (frame.ParentId != null)
{
_parentIds.TryAdd(frame.Id, frame.ParentId);
Expand All @@ -78,7 +72,7 @@ internal void AddFrame(Frame frame)

internal void RemoveFrame(Frame frame)
{
_frames.TryRemove(frame.Id, out var _);
_frames.TryRemove(frame.Id, out _);
_parentIds.TryRemove(frame.Id, out var _);

if (frame.ParentId != null)
Expand Down
30 changes: 26 additions & 4 deletions lib/PuppeteerSharp/Helpers/AsyncDictionaryHelper.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Globalization;
using System.Threading.Tasks;

Expand All @@ -9,15 +10,14 @@ internal class AsyncDictionaryHelper<TKey, TValue>
{
private readonly string _timeoutMessage;
private readonly MultiMap<TKey, TaskCompletionSource<TValue>> _pendingRequests = new();
private readonly ConcurrentDictionary<TKey, TValue> _dictionary;
private readonly ConcurrentDictionary<TKey, TValue> _dictionary = new();

public AsyncDictionaryHelper(ConcurrentDictionary<TKey, TValue> dictionary, string timeoutMessage)
public AsyncDictionaryHelper(string timeoutMessage)
{
_dictionary = dictionary;
_timeoutMessage = timeoutMessage;
}

internal ConcurrentDictionary<TKey, TValue> InnerDictionary => _dictionary;
internal ICollection<TValue> Values => _dictionary.Values;

internal async Task<TValue> GetItemAsync(TKey key)
{
Expand Down Expand Up @@ -58,5 +58,27 @@ internal void AddItem(TKey key, TValue value)
tcs.TrySetResult(value);
}
}

internal bool TryRemove(TKey key, out TValue value)
{
var result = _dictionary.TryRemove(key, out value);
_ = _pendingRequests.TryRemove(key, out _);
return result;
}

internal void Clear()
{
_dictionary.Clear();
_pendingRequests.Clear();
}

internal TValue GetValueOrDefault(TKey key)
=> _dictionary.GetValueOrDefault(key);

internal bool TryGetValue(TKey key, out TValue value)
=> _dictionary.TryGetValue(key, out value);

internal bool ContainsKey(TKey key)
=> _dictionary.ContainsKey(key);
}
}
6 changes: 6 additions & 0 deletions lib/PuppeteerSharp/Helpers/MultiMap.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@ internal bool Has(TKey key, TValue value)
internal bool Delete(TKey key, TValue value)
=> _map.TryGetValue(key, out var set) && set.Remove(value);

internal bool TryRemove(TKey key, out ICollection<TValue> value)
=> _map.TryRemove(key, out value);

internal TValue FirstValue(TKey key)
=> _map.TryGetValue(key, out var set) ? set.FirstOrDefault() : default;

internal void Clear()
=> _map.Clear();
}
}
2 changes: 1 addition & 1 deletion lib/PuppeteerSharp/Target.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public class Target : ITarget

/// <inheritdoc/>
public ITarget Opener => TargetInfo.OpenerId != null ?
((Browser)Browser).TargetManager.GetAvailableTargets().InnerDictionary.GetValueOrDefault(TargetInfo.OpenerId) : null;
((Browser)Browser).TargetManager.GetAvailableTargets().GetValueOrDefault(TargetInfo.OpenerId) : null;

/// <inheritdoc/>
public IBrowser Browser => BrowserContext.Browser;
Expand Down

0 comments on commit f62ec11

Please sign in to comment.