diff --git a/src/Titanium.Web.Proxy/Network/CachedCertificate.cs b/src/Titanium.Web.Proxy/Network/CachedCertificate.cs
index 818bb812c..c68904fc1 100644
--- a/src/Titanium.Web.Proxy/Network/CachedCertificate.cs
+++ b/src/Titanium.Web.Proxy/Network/CachedCertificate.cs
@@ -1,23 +1,24 @@
using System;
using System.Security.Cryptography.X509Certificates;
+using System.Threading.Tasks;
namespace Titanium.Web.Proxy.Network
{
///
/// An object that holds the cached certificate
///
- internal class CachedCertificate
+ internal sealed class CachedCertificate
{
- internal CachedCertificate()
- {
- LastAccess = DateTime.Now;
- }
-
internal X509Certificate2 Certificate { get; set; }
///
- /// last time this certificate was used
- /// Usefull in determining its cache lifetime
+ /// Certificate creation task.
+ ///
+ internal Task CreationTask { get; set; }
+
+ ///
+ /// Last time this certificate was used.
+ /// Useful in determining its cache lifetime.
///
internal DateTime LastAccess { get; set; }
}
diff --git a/src/Titanium.Web.Proxy/Network/CertificateManager.cs b/src/Titanium.Web.Proxy/Network/CertificateManager.cs
index a940ed8b4..e39b2a8eb 100644
--- a/src/Titanium.Web.Proxy/Network/CertificateManager.cs
+++ b/src/Titanium.Web.Proxy/Network/CertificateManager.cs
@@ -4,8 +4,8 @@
using System.Diagnostics;
using System.IO;
using System.Linq;
-using System.Reflection;
using System.Security.Cryptography.X509Certificates;
+using System.Threading;
using System.Threading.Tasks;
using Titanium.Web.Proxy.Helpers;
using Titanium.Web.Proxy.Network.Certificate;
@@ -30,7 +30,6 @@ public enum CertificateEngine
/// Bug #468 Reported.
///
DefaultWindows = 1
-
}
///
@@ -45,9 +44,11 @@ public sealed class CertificateManager : IDisposable
///
/// Cache dictionary
///
- private readonly ConcurrentDictionary certificateCache;
+ private readonly ConcurrentDictionary cachedCertificates;
+
+ private readonly CancellationTokenSource clearCertificatesTokenSource;
- private readonly ConcurrentDictionary> pendingCertificateCreationTasks;
+ private readonly object rootCertCreationLock;
private ICertificateMaker certEngine;
@@ -55,12 +56,12 @@ public sealed class CertificateManager : IDisposable
private string issuer;
- private bool pfxFileExists;
-
private X509Certificate2 rootCertificate;
private string rootCertificateName;
+ private ICertificateCache certificateCache;
+
///
/// Initializes a new instance of the class.
///
@@ -99,11 +100,14 @@ internal CertificateManager(string rootCertificateName, string rootCertificateIs
CertificateEngine = CertificateEngine.BouncyCastle;
- certificateCache = new ConcurrentDictionary();
- pendingCertificateCreationTasks = new ConcurrentDictionary>();
- }
+ cachedCertificates = new ConcurrentDictionary();
+
+ clearCertificatesTokenSource = new CancellationTokenSource();
- private bool clearCertificates { get; set; }
+ certificateCache = new DefaultCertificateDiskCache();
+
+ rootCertCreationLock = new object();
+ }
///
/// Is the root certificate used by this proxy is valid?
@@ -122,7 +126,7 @@ internal CertificateManager(string rootCertificateName, string rootCertificateIs
internal bool MachineTrustRoot { get; set; }
///
- /// Whether trust operations should be done with elevated privillages
+ /// Whether trust operations should be done with elevated privileges
/// Will prompt with UAC if required. Works only on Windows.
///
internal bool TrustRootAsAdministrator { get; set; }
@@ -215,14 +219,24 @@ public X509Certificate2 RootCertificate
}
///
- /// Save all fake certificates in folder "crts" (will be created in proxy dll directory).
+ /// Save all fake certificates using .
/// for can load the certificate and not make new certificate every time.
///
public bool SaveFakeCertificates { get; set; } = false;
+ ///
+ /// The service to save fake certificates.
+ /// The default storage saves certificates in folder "crts" (will be created in proxy dll directory).
+ ///
+ public ICertificateCache CertificateStorage
+ {
+ get => certificateCache;
+ set => certificateCache = value ?? new DefaultCertificateDiskCache();
+ }
+
///
/// Overwrite Root certificate file.
- /// true : replace an existing .pfx file if password is incorect or if RootCertificate = null.
+ /// true : replace an existing .pfx file if password is incorrect or if RootCertificate = null.
///
public bool OverwritePfxFile { get; set; } = true;
@@ -241,52 +255,7 @@ public X509Certificate2 RootCertificate
///
public void Dispose()
{
- }
-
- private string getRootCertificateDirectory()
- {
- string assemblyLocation = Assembly.GetExecutingAssembly().Location;
-
- // dynamically loaded assemblies returns string.Empty location
- if (assemblyLocation == string.Empty)
- {
- assemblyLocation = Assembly.GetEntryAssembly().Location;
- }
-
- string path = Path.GetDirectoryName(assemblyLocation);
- if (path == null)
- {
- throw new NullReferenceException();
- }
-
- return path;
- }
-
- private string getCertificatePath()
- {
- string path = getRootCertificateDirectory();
-
- string certPath = Path.Combine(path, "crts");
- if (!Directory.Exists(certPath))
- {
- Directory.CreateDirectory(certPath);
- }
-
- return certPath;
- }
-
- private string getRootCertificatePath()
- {
- string path = getRootCertificateDirectory();
-
- string fileName = PfxFilePath;
- if (fileName == string.Empty)
- {
- fileName = Path.Combine(path, "rootCert.pfx");
- StorageFlag = X509KeyStorageFlags.Exportable;
- }
-
- return fileName;
+ clearCertificatesTokenSource.Dispose();
}
///
@@ -412,43 +381,37 @@ private X509Certificate2 makeCertificate(string certificateName, bool isRootCert
///
internal X509Certificate2 CreateCertificate(string certificateName, bool isRootCertificate)
{
- X509Certificate2 certificate = null;
+ X509Certificate2 certificate;
try
{
if (!isRootCertificate && SaveFakeCertificates)
{
- string path = getCertificatePath();
- string subjectName = ProxyConstants.CNRemoverRegex.Replace(certificateName, string.Empty);
- subjectName = subjectName.Replace("*", "$x$");
- string certificatePath = Path.Combine(path, subjectName + ".pfx");
+ string subjectName = ProxyConstants.CNRemoverRegex
+ .Replace(certificateName, string.Empty)
+ .Replace("*", "$x$");
+
+ try
+ {
+ certificate = certificateCache.LoadCertificate(subjectName, StorageFlag);
+ }
+ catch (Exception e)
+ {
+ ExceptionFunc(new Exception("Failed to load fake certificate.", e));
+ certificate = null;
+ }
- if (!File.Exists(certificatePath))
+ if (certificate == null)
{
certificate = makeCertificate(certificateName, false);
- // store as cache
try
{
- var exported = certificate.Export(X509ContentType.Pkcs12);
- File.WriteAllBytes(certificatePath, exported);
+ certificateCache.SaveCertificate(subjectName, certificate);
}
catch (Exception e)
{
ExceptionFunc(new Exception("Failed to save fake certificate.", e));
}
-
- }
- else
- {
- try
- {
- certificate = new X509Certificate2(certificatePath, string.Empty, StorageFlag);
- }
- catch
- {
- // if load failed create again
- certificate = makeCertificate(certificateName, false);
- }
}
}
else
@@ -459,6 +422,7 @@ internal X509Certificate2 CreateCertificate(string certificateName, bool isRootC
catch (Exception e)
{
ExceptionFunc(e);
+ certificate = null;
}
return certificate;
@@ -472,40 +436,41 @@ internal X509Certificate2 CreateCertificate(string certificateName, bool isRootC
internal async Task CreateCertificateAsync(string certificateName)
{
// check in cache first
- if (certificateCache.TryGetValue(certificateName, out var cached))
+ var item = cachedCertificates.GetOrAdd(certificateName, _ =>
{
- cached.LastAccess = DateTime.Now;
- return cached.Certificate;
- }
+ var cached = new CachedCertificate();
+ cached.CreationTask = Task.Run(() =>
+ {
+ var certificate = CreateCertificate(certificateName, false);
+
+ // see http://www.albahari.com/threading/part4.aspx for the explanation
+ // why Thread.MemoryBarrier is used here and below
+ cached.Certificate = certificate;
+ Thread.MemoryBarrier();
+ cached.CreationTask = null;
+ Thread.MemoryBarrier();
+ return certificate;
+ });
- // handle burst requests with same certificate name
- // by checking for existing task for same certificate name
- if (pendingCertificateCreationTasks.TryGetValue(certificateName, out var task))
- {
- return await task;
- }
+ return cached;
+ });
- // run certificate creation task & add it to pending tasks
- task = Task.Run(() =>
+ item.LastAccess = DateTime.Now;
+
+ if (item.Certificate != null)
{
- var result = CreateCertificate(certificateName, false);
- if (result != null)
- {
- certificateCache.TryAdd(certificateName, new CachedCertificate
- {
- Certificate = result
- });
- }
+ return item.Certificate;
+ }
- return result;
- });
- pendingCertificateCreationTasks.TryAdd(certificateName, task);
+ // handle burst requests with same certificate name
+ // by checking for existing task
+ Thread.MemoryBarrier();
+ var task = item.CreationTask;
- // cleanup pending tasks & return result
- var certificate = await task;
- pendingCertificateCreationTasks.TryRemove(certificateName, out task);
+ Thread.MemoryBarrier();
- return certificate;
+ // return result
+ return item.Certificate ?? await task;
}
///
@@ -513,20 +478,27 @@ internal async Task CreateCertificateAsync(string certificateN
///
internal async void ClearIdleCertificates()
{
- clearCertificates = true;
- while (clearCertificates)
+ var cancellationToken = clearCertificatesTokenSource.Token;
+ while (!cancellationToken.IsCancellationRequested)
{
var cutOff = DateTime.Now.AddMinutes(-1 * CertificateCacheTimeOutMinutes);
- var outdated = certificateCache.Where(x => x.Value.LastAccess < cutOff).ToList();
+ var outdated = cachedCertificates.Where(x => x.Value.LastAccess < cutOff).ToList();
foreach (var cache in outdated)
{
- certificateCache.TryRemove(cache.Key, out _);
+ cachedCertificates.TryRemove(cache.Key, out _);
}
// after a minute come back to check for outdated certificates in cache
- await Task.Delay(1000 * 60);
+ try
+ {
+ await Task.Delay(1000 * 60, cancellationToken);
+ }
+ catch (TaskCanceledException)
+ {
+ return;
+ }
}
}
@@ -535,7 +507,7 @@ internal async void ClearIdleCertificates()
///
internal void StopClearIdleCertificates()
{
- clearCertificates = false;
+ clearCertificatesTokenSource.Cancel();
}
///
@@ -547,53 +519,68 @@ internal void StopClearIdleCertificates()
///
public bool CreateRootCertificate(bool persistToFile = true)
{
- if (persistToFile && RootCertificate == null)
- {
- RootCertificate = LoadRootCertificate();
- }
-
- if (RootCertificate != null)
+ lock (rootCertCreationLock)
{
- return true;
- }
-
- if (!OverwritePfxFile && pfxFileExists)
- {
- return false;
- }
+ if (persistToFile && RootCertificate == null)
+ {
+ RootCertificate = LoadRootCertificate();
+ }
- try
- {
- RootCertificate = CreateCertificate(RootCertificateName, true);
- }
- catch (Exception e)
- {
- ExceptionFunc(e);
- }
+ if (RootCertificate != null)
+ {
+ return true;
+ }
- if (persistToFile && RootCertificate != null)
- {
- try
+ if (!OverwritePfxFile)
{
try
{
- Directory.Delete(getCertificatePath(), true);
+ var rootCert = certificateCache.LoadRootCertificate(PfxFilePath, PfxPassword,
+ X509KeyStorageFlags.Exportable);
+
+ if (rootCert != null)
+ {
+ return false;
+ }
}
catch
{
- // ignore
+ // root cert cannot be loaded
}
+ }
- string fileName = getRootCertificatePath();
- File.WriteAllBytes(fileName, RootCertificate.Export(X509ContentType.Pkcs12, PfxPassword));
+ try
+ {
+ RootCertificate = CreateCertificate(RootCertificateName, true);
}
catch (Exception e)
{
ExceptionFunc(e);
}
- }
- return RootCertificate != null;
+ if (persistToFile && RootCertificate != null)
+ {
+ try
+ {
+ try
+ {
+ certificateCache.Clear();
+ }
+ catch
+ {
+ // ignore
+ }
+
+ certificateCache.SaveRootCertificate(PfxFilePath, PfxPassword, RootCertificate);
+ }
+ catch (Exception e)
+ {
+ ExceptionFunc(e);
+ }
+ }
+
+ return RootCertificate != null;
+ }
}
///
@@ -602,16 +589,9 @@ public bool CreateRootCertificate(bool persistToFile = true)
///
public X509Certificate2 LoadRootCertificate()
{
- string fileName = getRootCertificatePath();
- pfxFileExists = File.Exists(fileName);
- if (!pfxFileExists)
- {
- return null;
- }
-
try
{
- return new X509Certificate2(fileName, PfxPassword, StorageFlag);
+ return certificateCache.LoadRootCertificate(PfxFilePath, PfxPassword, X509KeyStorageFlags.Exportable);
}
catch (Exception e)
{
@@ -629,7 +609,7 @@ public X509Certificate2 LoadRootCertificate()
///
/// Set a password for the .pfx file.
///
- /// true : replace an existing .pfx file if password is incorect or if
+ /// true : replace an existing .pfx file if password is incorrect or if
/// RootCertificate==null.
///
///
@@ -861,7 +841,7 @@ public bool RemoveTrustedRootCertificateAsAdmin(bool machineTrusted = false)
ErrorDialog = false,
WindowStyle = ProcessWindowStyle.Hidden
},
-
+
// currentUser\Personal & currentMachine\Personal
new ProcessStartInfo
{
@@ -905,6 +885,7 @@ public bool RemoveTrustedRootCertificateAsAdmin(bool machineTrusted = false)
public void ClearRootCertificate()
{
certificateCache.Clear();
+ cachedCertificates.Clear();
rootCertificate = null;
}
}
diff --git a/src/Titanium.Web.Proxy/Network/DefaultCertificateDiskCache.cs b/src/Titanium.Web.Proxy/Network/DefaultCertificateDiskCache.cs
new file mode 100644
index 000000000..0690c86d4
--- /dev/null
+++ b/src/Titanium.Web.Proxy/Network/DefaultCertificateDiskCache.cs
@@ -0,0 +1,123 @@
+using System;
+using System.IO;
+using System.Reflection;
+using System.Security.Cryptography.X509Certificates;
+
+namespace Titanium.Web.Proxy.Network
+{
+ internal sealed class DefaultCertificateDiskCache : ICertificateCache
+ {
+ private const string defaultCertificateDirectoryName = "crts";
+ private const string defaultCertificateFileExtension = ".pfx";
+ private const string defaultRootCertificateFileName = "rootCert" + defaultCertificateFileExtension;
+ private string rootCertificatePath;
+ private string certificatePath;
+
+ public X509Certificate2 LoadRootCertificate(string name, string password, X509KeyStorageFlags storageFlags)
+ {
+ string filePath = getRootCertificatePath(name);
+ return loadCertificate(filePath, password, storageFlags);
+ }
+
+ public void SaveRootCertificate(string name, string password, X509Certificate2 certificate)
+ {
+ string filePath = getRootCertificatePath(name);
+ byte[] exported = certificate.Export(X509ContentType.Pkcs12, password);
+ File.WriteAllBytes(filePath, exported);
+ }
+
+ ///
+ public X509Certificate2 LoadCertificate(string subjectName, X509KeyStorageFlags storageFlags)
+ {
+ string filePath = Path.Combine(getCertificatePath(), subjectName + defaultCertificateFileExtension);
+ return loadCertificate(filePath, string.Empty, storageFlags);
+ }
+
+ ///
+ public void SaveCertificate(string subjectName, X509Certificate2 certificate)
+ {
+ string filePath = Path.Combine(getCertificatePath(), subjectName + defaultCertificateFileExtension);
+ byte[] exported = certificate.Export(X509ContentType.Pkcs12);
+ File.WriteAllBytes(filePath, exported);
+ }
+
+ public void Clear()
+ {
+ try
+ {
+ Directory.Delete(getCertificatePath(), true);
+ }
+ catch (DirectoryNotFoundException)
+ {
+ // do nothing
+ }
+
+ certificatePath = null;
+ }
+
+ private X509Certificate2 loadCertificate(string filePath, string password, X509KeyStorageFlags storageFlags)
+ {
+ byte[] exported;
+ try
+ {
+ exported = File.ReadAllBytes(filePath);
+ }
+ catch (IOException)
+ {
+ // file or directory not found
+ return null;
+ }
+
+ return new X509Certificate2(exported, password, storageFlags);
+ }
+
+ private string getRootCertificatePath(string filePath)
+ {
+ if (Path.IsPathRooted(filePath))
+ {
+ return filePath;
+ }
+
+ return Path.Combine(getRootCertificateDirectory(),
+ string.IsNullOrEmpty(filePath) ? defaultRootCertificateFileName : filePath);
+ }
+
+ private string getCertificatePath()
+ {
+ if (certificatePath == null)
+ {
+ string path = getRootCertificateDirectory();
+
+ string certPath = Path.Combine(path, defaultCertificateDirectoryName);
+ if (!Directory.Exists(certPath))
+ {
+ Directory.CreateDirectory(certPath);
+ }
+
+ certificatePath = certPath;
+ }
+
+ return certificatePath;
+ }
+
+ private string getRootCertificateDirectory()
+ {
+ if (rootCertificatePath == null)
+ {
+ string assemblyLocation = GetType().Assembly.Location;
+
+ // dynamically loaded assemblies returns string.Empty location
+ if (assemblyLocation == string.Empty)
+ {
+ assemblyLocation = Assembly.GetEntryAssembly().Location;
+ }
+
+ string path = Path.GetDirectoryName(assemblyLocation);
+
+ rootCertificatePath = path ?? throw new NullReferenceException();
+ }
+
+ return rootCertificatePath;
+ }
+ }
+}
diff --git a/src/Titanium.Web.Proxy/Network/ICertificateCache.cs b/src/Titanium.Web.Proxy/Network/ICertificateCache.cs
new file mode 100644
index 000000000..bd62348f8
--- /dev/null
+++ b/src/Titanium.Web.Proxy/Network/ICertificateCache.cs
@@ -0,0 +1,32 @@
+using System.Security.Cryptography.X509Certificates;
+
+namespace Titanium.Web.Proxy.Network
+{
+ public interface ICertificateCache
+ {
+ ///
+ /// Loads the root certificate from the storage.
+ ///
+ X509Certificate2 LoadRootCertificate(string name, string password, X509KeyStorageFlags storageFlags);
+
+ ///
+ /// Saves the root certificate to the storage.
+ ///
+ void SaveRootCertificate(string name, string password, X509Certificate2 certificate);
+
+ ///
+ /// Loads certificate from the storage. Returns true if certificate does not exist.
+ ///
+ X509Certificate2 LoadCertificate(string subjectName, X509KeyStorageFlags storageFlags);
+
+ ///
+ /// Stores certificate into the storage.
+ ///
+ void SaveCertificate(string subjectName, X509Certificate2 certificate);
+
+ ///
+ /// Clears the storage.
+ ///
+ void Clear();
+ }
+}