diff --git a/src/Titanium.Web.Proxy/Network/CachedCertificate.cs b/src/Titanium.Web.Proxy/Network/CachedCertificate.cs index 818bb812c..c68904fc1 100644 --- a/src/Titanium.Web.Proxy/Network/CachedCertificate.cs +++ b/src/Titanium.Web.Proxy/Network/CachedCertificate.cs @@ -1,23 +1,24 @@ using System; using System.Security.Cryptography.X509Certificates; +using System.Threading.Tasks; namespace Titanium.Web.Proxy.Network { /// /// An object that holds the cached certificate /// - internal class CachedCertificate + internal sealed class CachedCertificate { - internal CachedCertificate() - { - LastAccess = DateTime.Now; - } - internal X509Certificate2 Certificate { get; set; } /// - /// last time this certificate was used - /// Usefull in determining its cache lifetime + /// Certificate creation task. + /// + internal Task CreationTask { get; set; } + + /// + /// Last time this certificate was used. + /// Useful in determining its cache lifetime. /// internal DateTime LastAccess { get; set; } } diff --git a/src/Titanium.Web.Proxy/Network/CertificateManager.cs b/src/Titanium.Web.Proxy/Network/CertificateManager.cs index a940ed8b4..e39b2a8eb 100644 --- a/src/Titanium.Web.Proxy/Network/CertificateManager.cs +++ b/src/Titanium.Web.Proxy/Network/CertificateManager.cs @@ -4,8 +4,8 @@ using System.Diagnostics; using System.IO; using System.Linq; -using System.Reflection; using System.Security.Cryptography.X509Certificates; +using System.Threading; using System.Threading.Tasks; using Titanium.Web.Proxy.Helpers; using Titanium.Web.Proxy.Network.Certificate; @@ -30,7 +30,6 @@ public enum CertificateEngine /// Bug #468 Reported. /// DefaultWindows = 1 - } /// @@ -45,9 +44,11 @@ public sealed class CertificateManager : IDisposable /// /// Cache dictionary /// - private readonly ConcurrentDictionary certificateCache; + private readonly ConcurrentDictionary cachedCertificates; + + private readonly CancellationTokenSource clearCertificatesTokenSource; - private readonly ConcurrentDictionary> pendingCertificateCreationTasks; + private readonly object rootCertCreationLock; private ICertificateMaker certEngine; @@ -55,12 +56,12 @@ public sealed class CertificateManager : IDisposable private string issuer; - private bool pfxFileExists; - private X509Certificate2 rootCertificate; private string rootCertificateName; + private ICertificateCache certificateCache; + /// /// Initializes a new instance of the class. /// @@ -99,11 +100,14 @@ internal CertificateManager(string rootCertificateName, string rootCertificateIs CertificateEngine = CertificateEngine.BouncyCastle; - certificateCache = new ConcurrentDictionary(); - pendingCertificateCreationTasks = new ConcurrentDictionary>(); - } + cachedCertificates = new ConcurrentDictionary(); + + clearCertificatesTokenSource = new CancellationTokenSource(); - private bool clearCertificates { get; set; } + certificateCache = new DefaultCertificateDiskCache(); + + rootCertCreationLock = new object(); + } /// /// Is the root certificate used by this proxy is valid? @@ -122,7 +126,7 @@ internal CertificateManager(string rootCertificateName, string rootCertificateIs internal bool MachineTrustRoot { get; set; } /// - /// Whether trust operations should be done with elevated privillages + /// Whether trust operations should be done with elevated privileges /// Will prompt with UAC if required. Works only on Windows. /// internal bool TrustRootAsAdministrator { get; set; } @@ -215,14 +219,24 @@ public X509Certificate2 RootCertificate } /// - /// Save all fake certificates in folder "crts" (will be created in proxy dll directory). + /// Save all fake certificates using . /// for can load the certificate and not make new certificate every time. /// public bool SaveFakeCertificates { get; set; } = false; + /// + /// The service to save fake certificates. + /// The default storage saves certificates in folder "crts" (will be created in proxy dll directory). + /// + public ICertificateCache CertificateStorage + { + get => certificateCache; + set => certificateCache = value ?? new DefaultCertificateDiskCache(); + } + /// /// Overwrite Root certificate file. - /// true : replace an existing .pfx file if password is incorect or if RootCertificate = null. + /// true : replace an existing .pfx file if password is incorrect or if RootCertificate = null. /// public bool OverwritePfxFile { get; set; } = true; @@ -241,52 +255,7 @@ public X509Certificate2 RootCertificate /// public void Dispose() { - } - - private string getRootCertificateDirectory() - { - string assemblyLocation = Assembly.GetExecutingAssembly().Location; - - // dynamically loaded assemblies returns string.Empty location - if (assemblyLocation == string.Empty) - { - assemblyLocation = Assembly.GetEntryAssembly().Location; - } - - string path = Path.GetDirectoryName(assemblyLocation); - if (path == null) - { - throw new NullReferenceException(); - } - - return path; - } - - private string getCertificatePath() - { - string path = getRootCertificateDirectory(); - - string certPath = Path.Combine(path, "crts"); - if (!Directory.Exists(certPath)) - { - Directory.CreateDirectory(certPath); - } - - return certPath; - } - - private string getRootCertificatePath() - { - string path = getRootCertificateDirectory(); - - string fileName = PfxFilePath; - if (fileName == string.Empty) - { - fileName = Path.Combine(path, "rootCert.pfx"); - StorageFlag = X509KeyStorageFlags.Exportable; - } - - return fileName; + clearCertificatesTokenSource.Dispose(); } /// @@ -412,43 +381,37 @@ private X509Certificate2 makeCertificate(string certificateName, bool isRootCert /// internal X509Certificate2 CreateCertificate(string certificateName, bool isRootCertificate) { - X509Certificate2 certificate = null; + X509Certificate2 certificate; try { if (!isRootCertificate && SaveFakeCertificates) { - string path = getCertificatePath(); - string subjectName = ProxyConstants.CNRemoverRegex.Replace(certificateName, string.Empty); - subjectName = subjectName.Replace("*", "$x$"); - string certificatePath = Path.Combine(path, subjectName + ".pfx"); + string subjectName = ProxyConstants.CNRemoverRegex + .Replace(certificateName, string.Empty) + .Replace("*", "$x$"); + + try + { + certificate = certificateCache.LoadCertificate(subjectName, StorageFlag); + } + catch (Exception e) + { + ExceptionFunc(new Exception("Failed to load fake certificate.", e)); + certificate = null; + } - if (!File.Exists(certificatePath)) + if (certificate == null) { certificate = makeCertificate(certificateName, false); - // store as cache try { - var exported = certificate.Export(X509ContentType.Pkcs12); - File.WriteAllBytes(certificatePath, exported); + certificateCache.SaveCertificate(subjectName, certificate); } catch (Exception e) { ExceptionFunc(new Exception("Failed to save fake certificate.", e)); } - - } - else - { - try - { - certificate = new X509Certificate2(certificatePath, string.Empty, StorageFlag); - } - catch - { - // if load failed create again - certificate = makeCertificate(certificateName, false); - } } } else @@ -459,6 +422,7 @@ internal X509Certificate2 CreateCertificate(string certificateName, bool isRootC catch (Exception e) { ExceptionFunc(e); + certificate = null; } return certificate; @@ -472,40 +436,41 @@ internal X509Certificate2 CreateCertificate(string certificateName, bool isRootC internal async Task CreateCertificateAsync(string certificateName) { // check in cache first - if (certificateCache.TryGetValue(certificateName, out var cached)) + var item = cachedCertificates.GetOrAdd(certificateName, _ => { - cached.LastAccess = DateTime.Now; - return cached.Certificate; - } + var cached = new CachedCertificate(); + cached.CreationTask = Task.Run(() => + { + var certificate = CreateCertificate(certificateName, false); + + // see http://www.albahari.com/threading/part4.aspx for the explanation + // why Thread.MemoryBarrier is used here and below + cached.Certificate = certificate; + Thread.MemoryBarrier(); + cached.CreationTask = null; + Thread.MemoryBarrier(); + return certificate; + }); - // handle burst requests with same certificate name - // by checking for existing task for same certificate name - if (pendingCertificateCreationTasks.TryGetValue(certificateName, out var task)) - { - return await task; - } + return cached; + }); - // run certificate creation task & add it to pending tasks - task = Task.Run(() => + item.LastAccess = DateTime.Now; + + if (item.Certificate != null) { - var result = CreateCertificate(certificateName, false); - if (result != null) - { - certificateCache.TryAdd(certificateName, new CachedCertificate - { - Certificate = result - }); - } + return item.Certificate; + } - return result; - }); - pendingCertificateCreationTasks.TryAdd(certificateName, task); + // handle burst requests with same certificate name + // by checking for existing task + Thread.MemoryBarrier(); + var task = item.CreationTask; - // cleanup pending tasks & return result - var certificate = await task; - pendingCertificateCreationTasks.TryRemove(certificateName, out task); + Thread.MemoryBarrier(); - return certificate; + // return result + return item.Certificate ?? await task; } /// @@ -513,20 +478,27 @@ internal async Task CreateCertificateAsync(string certificateN /// internal async void ClearIdleCertificates() { - clearCertificates = true; - while (clearCertificates) + var cancellationToken = clearCertificatesTokenSource.Token; + while (!cancellationToken.IsCancellationRequested) { var cutOff = DateTime.Now.AddMinutes(-1 * CertificateCacheTimeOutMinutes); - var outdated = certificateCache.Where(x => x.Value.LastAccess < cutOff).ToList(); + var outdated = cachedCertificates.Where(x => x.Value.LastAccess < cutOff).ToList(); foreach (var cache in outdated) { - certificateCache.TryRemove(cache.Key, out _); + cachedCertificates.TryRemove(cache.Key, out _); } // after a minute come back to check for outdated certificates in cache - await Task.Delay(1000 * 60); + try + { + await Task.Delay(1000 * 60, cancellationToken); + } + catch (TaskCanceledException) + { + return; + } } } @@ -535,7 +507,7 @@ internal async void ClearIdleCertificates() /// internal void StopClearIdleCertificates() { - clearCertificates = false; + clearCertificatesTokenSource.Cancel(); } /// @@ -547,53 +519,68 @@ internal void StopClearIdleCertificates() /// public bool CreateRootCertificate(bool persistToFile = true) { - if (persistToFile && RootCertificate == null) - { - RootCertificate = LoadRootCertificate(); - } - - if (RootCertificate != null) + lock (rootCertCreationLock) { - return true; - } - - if (!OverwritePfxFile && pfxFileExists) - { - return false; - } + if (persistToFile && RootCertificate == null) + { + RootCertificate = LoadRootCertificate(); + } - try - { - RootCertificate = CreateCertificate(RootCertificateName, true); - } - catch (Exception e) - { - ExceptionFunc(e); - } + if (RootCertificate != null) + { + return true; + } - if (persistToFile && RootCertificate != null) - { - try + if (!OverwritePfxFile) { try { - Directory.Delete(getCertificatePath(), true); + var rootCert = certificateCache.LoadRootCertificate(PfxFilePath, PfxPassword, + X509KeyStorageFlags.Exportable); + + if (rootCert != null) + { + return false; + } } catch { - // ignore + // root cert cannot be loaded } + } - string fileName = getRootCertificatePath(); - File.WriteAllBytes(fileName, RootCertificate.Export(X509ContentType.Pkcs12, PfxPassword)); + try + { + RootCertificate = CreateCertificate(RootCertificateName, true); } catch (Exception e) { ExceptionFunc(e); } - } - return RootCertificate != null; + if (persistToFile && RootCertificate != null) + { + try + { + try + { + certificateCache.Clear(); + } + catch + { + // ignore + } + + certificateCache.SaveRootCertificate(PfxFilePath, PfxPassword, RootCertificate); + } + catch (Exception e) + { + ExceptionFunc(e); + } + } + + return RootCertificate != null; + } } /// @@ -602,16 +589,9 @@ public bool CreateRootCertificate(bool persistToFile = true) /// public X509Certificate2 LoadRootCertificate() { - string fileName = getRootCertificatePath(); - pfxFileExists = File.Exists(fileName); - if (!pfxFileExists) - { - return null; - } - try { - return new X509Certificate2(fileName, PfxPassword, StorageFlag); + return certificateCache.LoadRootCertificate(PfxFilePath, PfxPassword, X509KeyStorageFlags.Exportable); } catch (Exception e) { @@ -629,7 +609,7 @@ public X509Certificate2 LoadRootCertificate() /// /// Set a password for the .pfx file. /// - /// true : replace an existing .pfx file if password is incorect or if + /// true : replace an existing .pfx file if password is incorrect or if /// RootCertificate==null. /// /// @@ -861,7 +841,7 @@ public bool RemoveTrustedRootCertificateAsAdmin(bool machineTrusted = false) ErrorDialog = false, WindowStyle = ProcessWindowStyle.Hidden }, - + // currentUser\Personal & currentMachine\Personal new ProcessStartInfo { @@ -905,6 +885,7 @@ public bool RemoveTrustedRootCertificateAsAdmin(bool machineTrusted = false) public void ClearRootCertificate() { certificateCache.Clear(); + cachedCertificates.Clear(); rootCertificate = null; } } diff --git a/src/Titanium.Web.Proxy/Network/DefaultCertificateDiskCache.cs b/src/Titanium.Web.Proxy/Network/DefaultCertificateDiskCache.cs new file mode 100644 index 000000000..0690c86d4 --- /dev/null +++ b/src/Titanium.Web.Proxy/Network/DefaultCertificateDiskCache.cs @@ -0,0 +1,123 @@ +using System; +using System.IO; +using System.Reflection; +using System.Security.Cryptography.X509Certificates; + +namespace Titanium.Web.Proxy.Network +{ + internal sealed class DefaultCertificateDiskCache : ICertificateCache + { + private const string defaultCertificateDirectoryName = "crts"; + private const string defaultCertificateFileExtension = ".pfx"; + private const string defaultRootCertificateFileName = "rootCert" + defaultCertificateFileExtension; + private string rootCertificatePath; + private string certificatePath; + + public X509Certificate2 LoadRootCertificate(string name, string password, X509KeyStorageFlags storageFlags) + { + string filePath = getRootCertificatePath(name); + return loadCertificate(filePath, password, storageFlags); + } + + public void SaveRootCertificate(string name, string password, X509Certificate2 certificate) + { + string filePath = getRootCertificatePath(name); + byte[] exported = certificate.Export(X509ContentType.Pkcs12, password); + File.WriteAllBytes(filePath, exported); + } + + /// + public X509Certificate2 LoadCertificate(string subjectName, X509KeyStorageFlags storageFlags) + { + string filePath = Path.Combine(getCertificatePath(), subjectName + defaultCertificateFileExtension); + return loadCertificate(filePath, string.Empty, storageFlags); + } + + /// + public void SaveCertificate(string subjectName, X509Certificate2 certificate) + { + string filePath = Path.Combine(getCertificatePath(), subjectName + defaultCertificateFileExtension); + byte[] exported = certificate.Export(X509ContentType.Pkcs12); + File.WriteAllBytes(filePath, exported); + } + + public void Clear() + { + try + { + Directory.Delete(getCertificatePath(), true); + } + catch (DirectoryNotFoundException) + { + // do nothing + } + + certificatePath = null; + } + + private X509Certificate2 loadCertificate(string filePath, string password, X509KeyStorageFlags storageFlags) + { + byte[] exported; + try + { + exported = File.ReadAllBytes(filePath); + } + catch (IOException) + { + // file or directory not found + return null; + } + + return new X509Certificate2(exported, password, storageFlags); + } + + private string getRootCertificatePath(string filePath) + { + if (Path.IsPathRooted(filePath)) + { + return filePath; + } + + return Path.Combine(getRootCertificateDirectory(), + string.IsNullOrEmpty(filePath) ? defaultRootCertificateFileName : filePath); + } + + private string getCertificatePath() + { + if (certificatePath == null) + { + string path = getRootCertificateDirectory(); + + string certPath = Path.Combine(path, defaultCertificateDirectoryName); + if (!Directory.Exists(certPath)) + { + Directory.CreateDirectory(certPath); + } + + certificatePath = certPath; + } + + return certificatePath; + } + + private string getRootCertificateDirectory() + { + if (rootCertificatePath == null) + { + string assemblyLocation = GetType().Assembly.Location; + + // dynamically loaded assemblies returns string.Empty location + if (assemblyLocation == string.Empty) + { + assemblyLocation = Assembly.GetEntryAssembly().Location; + } + + string path = Path.GetDirectoryName(assemblyLocation); + + rootCertificatePath = path ?? throw new NullReferenceException(); + } + + return rootCertificatePath; + } + } +} diff --git a/src/Titanium.Web.Proxy/Network/ICertificateCache.cs b/src/Titanium.Web.Proxy/Network/ICertificateCache.cs new file mode 100644 index 000000000..bd62348f8 --- /dev/null +++ b/src/Titanium.Web.Proxy/Network/ICertificateCache.cs @@ -0,0 +1,32 @@ +using System.Security.Cryptography.X509Certificates; + +namespace Titanium.Web.Proxy.Network +{ + public interface ICertificateCache + { + /// + /// Loads the root certificate from the storage. + /// + X509Certificate2 LoadRootCertificate(string name, string password, X509KeyStorageFlags storageFlags); + + /// + /// Saves the root certificate to the storage. + /// + void SaveRootCertificate(string name, string password, X509Certificate2 certificate); + + /// + /// Loads certificate from the storage. Returns true if certificate does not exist. + /// + X509Certificate2 LoadCertificate(string subjectName, X509KeyStorageFlags storageFlags); + + /// + /// Stores certificate into the storage. + /// + void SaveCertificate(string subjectName, X509Certificate2 certificate); + + /// + /// Clears the storage. + /// + void Clear(); + } +}