Skip to content

Commit

Permalink
Load ssl certificate for each request
Browse files Browse the repository at this point in the history
  • Loading branch information
jjxtra committed Aug 26, 2018
1 parent 88d7269 commit cc45b36
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 30 deletions.
4 changes: 2 additions & 2 deletions MailDemonLog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,9 @@ public static void Error(Exception ex)
/// </summary>
/// <param name="text">Text</param>
/// <param name="ex">Error</param>
public static void Error(string text, Exception ex)
public static void Error(string text, Exception ex = null)
{
Write(LogLevel.Error, text + ": " + ex.ToString());
Write(LogLevel.Error, text + (ex == null ? string.Empty : ": " + ex.ToString()));
}

/// <summary>
Expand Down
64 changes: 36 additions & 28 deletions MailDemonService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ private class CacheEntry

private TcpListenerActive server;

private readonly Encoding utf8Encoding = new UTF8Encoding(false);
private readonly List<MailDemonUser> users = new List<MailDemonUser>();
private readonly MemoryCache cache = new MemoryCache(new MemoryCacheOptions { SizeLimit = (1024 * 1024 * 16), CompactionPercentage = 0.9 });
private readonly Dictionary<string, Regex> ignoreCertificateErrorsRegex = new Dictionary<string, Regex>(StringComparer.OrdinalIgnoreCase); // domain,regex
Expand All @@ -68,9 +69,6 @@ private class CacheEntry
private readonly string sslCertificateFile;
private readonly string sslCertificatePrivateKeyFile;
private readonly SecureString sslCertificatePassword;
private readonly Timer sslCertificateTimer;

private X509Certificate2 sslCertificate;

public string Domain { get; private set; }
public IReadOnlyList<MailDemonUser> Users { get { return users; } }
Expand Down Expand Up @@ -104,8 +102,6 @@ public MailDemonService(string[] args, IConfiguration configuration)
{
sslCertificatePassword.AppendChar(c);
}
sslCertificateTimer = new Timer(SslCertificateTimerCallback, null, TimeSpan.FromDays(1.0), TimeSpan.FromDays(1.0));
LoadSslCertificate();
}
IConfigurationSection ignoreRegexSection = rootSection.GetSection("ignoreCertificateErrorsRegex");
if (ignoreRegexSection != null)
Expand Down Expand Up @@ -190,6 +186,7 @@ private async Task ProcessConnection()
MailDemonUser authenticatedUser = null;
client.ReceiveTimeout = 5000;
client.SendTimeout = 5000;
X509Certificate2 currentSslCertificate = null;

try
{
Expand All @@ -208,10 +205,12 @@ private async Task ProcessConnection()
// create comm streams
SslStream sslStream = null;
StreamReader reader = new StreamReader(stream, Encoding.UTF8);
StreamWriter writer = new StreamWriter(stream, new UTF8Encoding(false)) { AutoFlush = true, NewLine = "\r\n" };
StreamWriter writer = new StreamWriter(stream, utf8Encoding) { AutoFlush = true, NewLine = "\r\n" };
currentSslCertificate = LoadSslCertificate();

if (port == 465 || port == 587)
{
Tuple<SslStream, StreamReader, StreamWriter> tls = await StartTls(reader, writer, false);
Tuple<SslStream, StreamReader, StreamWriter> tls = await StartTls(reader, writer, false, currentSslCertificate);
if (tls == null)
{
throw new IOException("Failed to start TLS, ssl certificate failed to load");
Expand Down Expand Up @@ -242,7 +241,7 @@ private async Task ProcessConnection()
}
else if (line.StartsWith("EHLO", StringComparison.OrdinalIgnoreCase))
{
await HandleEhlo(writer, sslStream);
await HandleEhlo(writer, sslStream, currentSslCertificate);
}
else if (line.StartsWith("HELO", StringComparison.OrdinalIgnoreCase))
{
Expand All @@ -260,7 +259,7 @@ private async Task ProcessConnection()
}
else
{
Tuple<SslStream, StreamReader, StreamWriter> tls = await StartTls(reader, writer, true);
Tuple<SslStream, StreamReader, StreamWriter> tls = await StartTls(reader, writer, true, currentSslCertificate);
if (tls == null)
{
await writer.WriteLineAsync("503 Failed to start TLS");
Expand Down Expand Up @@ -309,6 +308,7 @@ private async Task ProcessConnection()
}
finally
{
currentSslCertificate?.Dispose();
MailDemonLog.Write(LogLevel.Info, "{0} disconnected", ipAddress);
}
}
Expand All @@ -321,7 +321,7 @@ private async Task<string> ReadLineAsync(StreamReader reader)
return line;
}

private async Task HandleEhlo(StreamWriter writer, SslStream sslStream)
private async Task HandleEhlo(StreamWriter writer, SslStream sslStream, X509Certificate2 sslCertificate)
{
await writer.WriteLineAsync($"250-SIZE 65536");
await writer.WriteLineAsync($"250-8BITMIME");
Expand Down Expand Up @@ -371,7 +371,7 @@ private async Task<MailDemonUser> Authenticate(StreamReader reader, StreamWriter
throw new InvalidOperationException("Authentication failed");
}

private async Task<Tuple<SslStream, StreamReader, StreamWriter>> StartTls(StreamReader reader, StreamWriter writer, bool sendReadyCommand)
private async Task<Tuple<SslStream, StreamReader, StreamWriter>> StartTls(StreamReader reader, StreamWriter writer, bool sendReadyCommand, X509Certificate2 sslCertificate)
{
if (sslCertificate == null)
{
Expand All @@ -390,9 +390,11 @@ private async Task<Tuple<SslStream, StreamReader, StreamWriter>> StartTls(Stream
try
{
// this can hang if the client does not authenticate ssl properly, so we kill it after 5 seconds
if (!sslStream.AuthenticateAsServerAsync(sslCertificate, false, System.Security.Authentication.SslProtocols.Tls12, true).Wait(5000))
Task authTask = sslStream.AuthenticateAsServerAsync(sslCertificate, false, System.Security.Authentication.SslProtocols.Tls12, true);
if (!authTask.Wait(5000))
{
// forces the authenticate as server to fail and throw exception
MailDemonLog.Error("Unable to authenticate as server, timeout");
sslStream.Dispose();
}
}
Expand All @@ -403,7 +405,7 @@ private async Task<Tuple<SslStream, StreamReader, StreamWriter>> StartTls(Stream

// create comm streams on top of ssl stream
StreamReader sslReader = new StreamReader(sslStream, Encoding.UTF8);
StreamWriter sslWriter = new StreamWriter(sslStream, new UTF8Encoding(false)) { AutoFlush = true, NewLine = "\r\n" };
StreamWriter sslWriter = new StreamWriter(sslStream, utf8Encoding) { AutoFlush = true, NewLine = "\r\n" };

return new Tuple<SslStream, StreamReader, StreamWriter>(sslStream, sslReader, sslWriter);
}
Expand Down Expand Up @@ -601,11 +603,6 @@ private void IncrementFailure(string ipAddress)
Interlocked.Increment(ref entry.Count);
}

private void SslCertificateTimerCallback(object state)
{
LoadSslCertificate();
}

private RSACryptoServiceProvider GetRSAProviderForPrivateKey(string pemPrivateKey)
{
RSACryptoServiceProvider rsaKey = new RSACryptoServiceProvider();
Expand All @@ -616,21 +613,32 @@ private RSACryptoServiceProvider GetRSAProviderForPrivateKey(string pemPrivateKe
return rsaKey;
}

private void LoadSslCertificate()
private X509Certificate2 LoadSslCertificate()
{
try
if (sslCertificatePassword != null)
{
sslCertificate = new X509Certificate2(File.ReadAllBytes(sslCertificateFile), sslCertificatePassword);
if (!sslCertificate.HasPrivateKey && !string.IsNullOrWhiteSpace(sslCertificatePrivateKeyFile))
for (int i = 0; i < 2; i++)
{
sslCertificate = sslCertificate.CopyWithPrivateKey(GetRSAProviderForPrivateKey(File.ReadAllText(sslCertificatePrivateKeyFile)));
try
{
X509Certificate2 newSslCertificate = new X509Certificate2(File.ReadAllBytes(sslCertificateFile), sslCertificatePassword);
if (!newSslCertificate.HasPrivateKey && !string.IsNullOrWhiteSpace(sslCertificatePrivateKeyFile))
{
newSslCertificate = newSslCertificate.CopyWithPrivateKey(GetRSAProviderForPrivateKey(File.ReadAllText(sslCertificatePrivateKeyFile)));
}
MailDemonLog.Write(LogLevel.Warn, "Loaded ssl certificate {0}", newSslCertificate);
return newSslCertificate;
}
catch (Exception ex)
{
MailDemonLog.Write(LogLevel.Error, "Error loading ssl certificate: {0}", ex);

// in case something is copying a new certificate, give it a second and try one more time
Thread.Sleep(1000);
}
}
MailDemonLog.Write(LogLevel.Warn, "Loaded ssl certificate {0}", sslCertificate);
}
catch (Exception ex)
{
MailDemonLog.Write(LogLevel.Error, "Error loading ssl certificate: {0}", ex);
}
return null;
}

private async Task SendMessage(MimeMessage msg, InternetAddress from, string domain)
Expand Down

0 comments on commit cc45b36

Please sign in to comment.