diff --git a/Examples/Titanium.Web.Proxy.Examples.Basic/ProxyTestController.cs b/Examples/Titanium.Web.Proxy.Examples.Basic/ProxyTestController.cs index 6e1b67067..fe9f16c1c 100644 --- a/Examples/Titanium.Web.Proxy.Examples.Basic/ProxyTestController.cs +++ b/Examples/Titanium.Web.Proxy.Examples.Basic/ProxyTestController.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.IO; using System.Net; using System.Threading.Tasks; using Titanium.Web.Proxy.EventArguments; @@ -17,6 +18,12 @@ public class ProxyTestController public ProxyTestController() { proxyServer = new ProxyServer(); + + //generate root certificate without storing it in file system + //proxyServer.CertificateEngine = Network.CertificateEngine.BouncyCastle; + //proxyServer.CertificateManager.CreateTrustedRootCertificate(false); + //proxyServer.CertificateManager.TrustRootCertificate(); + proxyServer.ExceptionFunc = exception => Console.WriteLine(exception.Message); proxyServer.TrustRootCertificate = true; @@ -40,7 +47,7 @@ public void StartProxy() //Exclude Https addresses you don't want to proxy //Useful for clients that use certificate pinning //for example google.com and dropbox.com - // ExcludedHttpsHostNameRegex = new List() { "google.com", "dropbox.com" } + ExcludedHttpsHostNameRegex = new List() { "dropbox.com" } //Include Https addresses you want to proxy (others will be excluded) //for example github.com @@ -92,18 +99,22 @@ public void Stop() proxyServer.ClientCertificateSelectionCallback -= OnCertificateSelection; proxyServer.Stop(); + + //remove the generated certificates + //proxyServer.CertificateManager.RemoveTrustedRootCertificates(); } //intecept & cancel redirect or update requests public async Task OnRequest(object sender, SessionEventArgs e) { + Console.WriteLine("Active Client Connections:" + ((ProxyServer) sender).ClientConnectionCount); Console.WriteLine(e.WebSession.Request.Url); //read request headers var requestHeaders = e.WebSession.Request.RequestHeaders; var method = e.WebSession.Request.Method.ToUpper(); - if ((method == "POST" || method == "PUT" || method == "PATCH")) + if (method == "POST" || method == "PUT" || method == "PATCH") { //Get/Set request body bytes byte[] bodyBytes = await e.GetRequestBody(); @@ -116,30 +127,32 @@ public async Task OnRequest(object sender, SessionEventArgs e) requestBodyHistory[e.Id] = bodyString; } - //To cancel a request with a custom HTML content - //Filter URL - if (e.WebSession.Request.RequestUri.AbsoluteUri.Contains("google.com")) - { - await e.Ok("" + - "

" + - "Website Blocked" + - "

" + - "

Blocked by titanium web proxy.

" + - "" + - ""); - } - - //Redirect example - if (e.WebSession.Request.RequestUri.AbsoluteUri.Contains("wikipedia.org")) - { - await e.Redirect("https://www.paypal.com"); - } + ////To cancel a request with a custom HTML content + ////Filter URL + //if (e.WebSession.Request.RequestUri.AbsoluteUri.Contains("google.com")) + //{ + // await e.Ok("" + + // "

" + + // "Website Blocked" + + // "

" + + // "

Blocked by titanium web proxy.

" + + // "" + + // ""); + //} + + ////Redirect example + //if (e.WebSession.Request.RequestUri.AbsoluteUri.Contains("wikipedia.org")) + //{ + // await e.Redirect("https://www.paypal.com"); + //} } //Modify response public async Task OnResponse(object sender, SessionEventArgs e) { - if(requestBodyHistory.ContainsKey(e.Id)) + Console.WriteLine("Active Server Connections:" + (sender as ProxyServer).ServerConnectionCount); + + if (requestBodyHistory.ContainsKey(e.Id)) { //access request body by looking up the shared dictionary using requestId var requestBody = requestBodyHistory[e.Id]; @@ -149,14 +162,14 @@ public async Task OnResponse(object sender, SessionEventArgs e) var responseHeaders = e.WebSession.Response.ResponseHeaders; // print out process id of current session - Console.WriteLine($"PID: {e.WebSession.ProcessId.Value}"); + //Console.WriteLine($"PID: {e.WebSession.ProcessId.Value}"); //if (!e.ProxySession.Request.Host.Equals("medeczane.sgk.gov.tr")) return; if (e.WebSession.Request.Method == "GET" || e.WebSession.Request.Method == "POST") { if (e.WebSession.Response.ResponseStatusCode == "200") { - if (e.WebSession.Response.ContentType!=null && e.WebSession.Response.ContentType.Trim().ToLower().Contains("text/html")) + if (e.WebSession.Response.ContentType != null && e.WebSession.Response.ContentType.Trim().ToLower().Contains("text/html")) { byte[] bodyBytes = await e.GetResponseBody(); await e.SetResponseBody(bodyBytes); diff --git a/Titanium.Web.Proxy/CertificateHandler.cs b/Titanium.Web.Proxy/CertificateHandler.cs index 275ffde1c..134798937 100644 --- a/Titanium.Web.Proxy/CertificateHandler.cs +++ b/Titanium.Web.Proxy/CertificateHandler.cs @@ -16,7 +16,7 @@ public partial class ProxyServer /// /// /// - private bool ValidateServerCertificate( + internal bool ValidateServerCertificate( object sender, X509Certificate certificate, X509Chain chain, @@ -65,7 +65,7 @@ private bool ValidateServerCertificate( /// /// /// - private X509Certificate SelectClientCertificate( + internal X509Certificate SelectClientCertificate( object sender, string targetHost, X509CertificateCollection localCertificates, diff --git a/Titanium.Web.Proxy/EventArguments/SessionEventArgs.cs b/Titanium.Web.Proxy/EventArguments/SessionEventArgs.cs index fda283883..f70bf8550 100644 --- a/Titanium.Web.Proxy/EventArguments/SessionEventArgs.cs +++ b/Titanium.Web.Proxy/EventArguments/SessionEventArgs.cs @@ -93,7 +93,7 @@ private async Task ReadRequestBody() { //GET request don't have a request body to read var method = WebSession.Request.Method.ToUpper(); - if ((method != "POST" && method != "PUT" && method != "PATCH")) + if (method != "POST" && method != "PUT" && method != "PATCH") { throw new BodyNotFoundException("Request don't have a body. " + "Please verify that this request is a Http POST/PUT/PATCH and request " + @@ -411,6 +411,7 @@ public async Task Ok(byte[] result, Dictionary headers) { response.ResponseHeaders = headers; } + response.HttpVersion = WebSession.Request.HttpVersion; response.ResponseBody = result; diff --git a/Titanium.Web.Proxy/Extensions/TcpExtensions.cs b/Titanium.Web.Proxy/Extensions/TcpExtensions.cs index 7657aadea..8dc4d5a64 100644 --- a/Titanium.Web.Proxy/Extensions/TcpExtensions.cs +++ b/Titanium.Web.Proxy/Extensions/TcpExtensions.cs @@ -1,4 +1,5 @@ using System.Net.Sockets; +using Titanium.Web.Proxy.Helpers; namespace Titanium.Web.Proxy.Extensions { @@ -32,5 +33,25 @@ internal static bool IsConnected(this Socket client) client.Blocking = blockingState; } } + + /// + /// Gets the local port from a native TCP row object. + /// + /// The TCP row. + /// The local port + internal static int GetLocalPort(this NativeMethods.TcpRow tcpRow) + { + return (tcpRow.localPort1 << 8) + tcpRow.localPort2 + (tcpRow.localPort3 << 24) + (tcpRow.localPort4 << 16); + } + + /// + /// Gets the remote port from a native TCP row object. + /// + /// The TCP row. + /// The remote port + internal static int GetRemotePort(this NativeMethods.TcpRow tcpRow) + { + return (tcpRow.remotePort1 << 8) + tcpRow.remotePort2 + (tcpRow.remotePort3 << 24) + (tcpRow.remotePort4 << 16); + } } } diff --git a/Titanium.Web.Proxy/Helpers/CustomBinaryReader.cs b/Titanium.Web.Proxy/Helpers/CustomBinaryReader.cs index 77028ec46..6d10c3d91 100644 --- a/Titanium.Web.Proxy/Helpers/CustomBinaryReader.cs +++ b/Titanium.Web.Proxy/Helpers/CustomBinaryReader.cs @@ -15,27 +15,14 @@ internal class CustomBinaryReader : IDisposable { private readonly CustomBufferedStream stream; private readonly int bufferSize; + private readonly byte[] staticBuffer; private readonly Encoding encoding; - [ThreadStatic] - private static byte[] staticBufferField; - - private byte[] staticBuffer - { - get - { - if (staticBufferField == null || staticBufferField.Length != bufferSize) - { - staticBufferField = new byte[bufferSize]; - } - - return staticBufferField; - } - } - internal CustomBinaryReader(CustomBufferedStream stream, int bufferSize) { this.stream = stream; + staticBuffer = new byte[bufferSize]; + this.bufferSize = bufferSize; //default to UTF-8 @@ -122,10 +109,13 @@ internal async Task ReadBytesAsync(long totalBytesToRead) { int bytesToRead = bufferSize; + var buffer = staticBuffer; if (totalBytesToRead < bufferSize) + { bytesToRead = (int) totalBytesToRead; + buffer = new byte[bytesToRead]; + } - var buffer = staticBuffer; int bytesRead; var totalBytesRead = 0; diff --git a/Titanium.Web.Proxy/Helpers/CustomBufferedStream.cs b/Titanium.Web.Proxy/Helpers/CustomBufferedStream.cs index 825faa5eb..78fa6db21 100644 --- a/Titanium.Web.Proxy/Helpers/CustomBufferedStream.cs +++ b/Titanium.Web.Proxy/Helpers/CustomBufferedStream.cs @@ -168,6 +168,7 @@ public override async Task CopyToAsync(Stream destination, int bufferSize, Cance if (bufferLength > 0) { await destination.WriteAsync(streamBuffer, bufferPos, bufferLength, cancellationToken); + bufferLength = 0; } await baseStream.CopyToAsync(destination, bufferSize, cancellationToken); @@ -307,6 +308,7 @@ public byte ReadByteFromBuffer() /// /// A task that represents the asynchronous write operation. /// + [DebuggerStepThrough] public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { return baseStream.WriteAsync(buffer, offset, count, cancellationToken); diff --git a/Titanium.Web.Proxy/Helpers/Network.cs b/Titanium.Web.Proxy/Helpers/Network.cs index 121926813..7ee2d9a55 100644 --- a/Titanium.Web.Proxy/Helpers/Network.cs +++ b/Titanium.Web.Proxy/Helpers/Network.cs @@ -8,8 +8,7 @@ internal class NetworkHelper { private static int FindProcessIdFromLocalPort(int port, IpVersion ipVersion) { - var tcpRow = TcpHelper.GetExtendedTcpTable(ipVersion).FirstOrDefault( - row => row.LocalEndPoint.Port == port); + var tcpRow = TcpHelper.GetTcpRowByLocalPort(ipVersion, port); return tcpRow?.ProcessId ?? 0; } diff --git a/Titanium.Web.Proxy/Helpers/Tcp.cs b/Titanium.Web.Proxy/Helpers/Tcp.cs index 8f38b802e..e64da17b6 100644 --- a/Titanium.Web.Proxy/Helpers/Tcp.cs +++ b/Titanium.Web.Proxy/Helpers/Tcp.cs @@ -93,21 +93,21 @@ internal static TcpTable GetExtendedTcpTable(IpVersion ipVersion) var ipVersionValue = ipVersion == IpVersion.Ipv4 ? NativeMethods.AfInet : NativeMethods.AfInet6; - if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, false, ipVersionValue, (int) NativeMethods.TcpTableType.OwnerPidAll, 0) != 0) + if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, false, ipVersionValue, (int)NativeMethods.TcpTableType.OwnerPidAll, 0) != 0) { try { tcpTable = Marshal.AllocHGlobal(tcpTableLength); - if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, true, ipVersionValue, (int) NativeMethods.TcpTableType.OwnerPidAll, 0) == 0) + if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, true, ipVersionValue, (int)NativeMethods.TcpTableType.OwnerPidAll, 0) == 0) { - NativeMethods.TcpTable table = (NativeMethods.TcpTable) Marshal.PtrToStructure(tcpTable, typeof(NativeMethods.TcpTable)); + NativeMethods.TcpTable table = (NativeMethods.TcpTable)Marshal.PtrToStructure(tcpTable, typeof(NativeMethods.TcpTable)); - IntPtr rowPtr = (IntPtr) ((long) tcpTable + Marshal.SizeOf(table.length)); + IntPtr rowPtr = (IntPtr)((long)tcpTable + Marshal.SizeOf(table.length)); for (int i = 0; i < table.length; ++i) { - tcpRows.Add(new TcpRow((NativeMethods.TcpRow) Marshal.PtrToStructure(rowPtr, typeof(NativeMethods.TcpRow)))); - rowPtr = (IntPtr) ((long) rowPtr + Marshal.SizeOf(typeof(NativeMethods.TcpRow))); + tcpRows.Add(new TcpRow((NativeMethods.TcpRow)Marshal.PtrToStructure(rowPtr, typeof(NativeMethods.TcpRow)))); + rowPtr = (IntPtr)((long)rowPtr + Marshal.SizeOf(typeof(NativeMethods.TcpRow))); } } } @@ -123,30 +123,71 @@ internal static TcpTable GetExtendedTcpTable(IpVersion ipVersion) return new TcpTable(tcpRows); } + /// + /// Gets the TCP row by local port number. + /// + /// . + internal static TcpRow GetTcpRowByLocalPort(IpVersion ipVersion, int localPort) + { + IntPtr tcpTable = IntPtr.Zero; + int tcpTableLength = 0; + + var ipVersionValue = ipVersion == IpVersion.Ipv4 ? NativeMethods.AfInet : NativeMethods.AfInet6; + + if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, false, ipVersionValue, (int)NativeMethods.TcpTableType.OwnerPidAll, 0) != 0) + { + try + { + tcpTable = Marshal.AllocHGlobal(tcpTableLength); + if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, true, ipVersionValue, (int)NativeMethods.TcpTableType.OwnerPidAll, 0) == 0) + { + NativeMethods.TcpTable table = (NativeMethods.TcpTable)Marshal.PtrToStructure(tcpTable, typeof(NativeMethods.TcpTable)); + + IntPtr rowPtr = (IntPtr)((long)tcpTable + Marshal.SizeOf(table.length)); + + for (int i = 0; i < table.length; ++i) + { + var tcpRow = (NativeMethods.TcpRow)Marshal.PtrToStructure(rowPtr, typeof(NativeMethods.TcpRow)); + if (tcpRow.GetLocalPort() == localPort) + { + return new TcpRow(tcpRow); + } + + rowPtr = (IntPtr)((long)rowPtr + Marshal.SizeOf(typeof(NativeMethods.TcpRow))); + } + } + } + finally + { + if (tcpTable != IntPtr.Zero) + { + Marshal.FreeHGlobal(tcpTable); + } + } + } + + return null; + } + /// /// relays the input clientStream to the server at the specified host name and port with the given httpCmd and headers as prefix /// Usefull for websocket requests /// - /// - /// + /// /// + /// /// /// /// /// - /// - /// - /// - /// /// /// - /// /// - internal static async Task SendRaw(int bufferSize, int connectionTimeOutSeconds, - string remoteHostName, int remotePort, string httpCmd, Version httpVersion, Dictionary requestHeaders, - bool isHttps, SslProtocols supportedProtocols, - RemoteCertificateValidationCallback remoteCertificateValidationCallback, LocalCertificateSelectionCallback localCertificateSelectionCallback, - Stream clientStream, TcpConnectionFactory tcpConnectionFactory, IPEndPoint upStreamEndPoint) + internal static async Task SendRaw(ProxyServer server, + string remoteHostName, int remotePort, + string httpCmd, Version httpVersion, Dictionary requestHeaders, + bool isHttps, + Stream clientStream, TcpConnectionFactory tcpConnectionFactory) { //prepare the prefix content StringBuilder sb = null; @@ -172,11 +213,10 @@ internal static async Task SendRaw(int bufferSize, int connectionTimeOutSeconds, sb.Append(ProxyConstants.NewLine); } - var tcpConnection = await tcpConnectionFactory.CreateClient(bufferSize, connectionTimeOutSeconds, + var tcpConnection = await tcpConnectionFactory.CreateClient(server, remoteHostName, remotePort, httpVersion, isHttps, - supportedProtocols, remoteCertificateValidationCallback, localCertificateSelectionCallback, - null, null, clientStream, upStreamEndPoint); + null, null, clientStream); try { @@ -192,6 +232,7 @@ internal static async Task SendRaw(int bufferSize, int connectionTimeOutSeconds, finally { tcpConnection.Dispose(); + server.ServerConnectionCount--; } } } diff --git a/Titanium.Web.Proxy/Http/HeaderParser.cs b/Titanium.Web.Proxy/Http/HeaderParser.cs new file mode 100644 index 000000000..27d8e051a --- /dev/null +++ b/Titanium.Web.Proxy/Http/HeaderParser.cs @@ -0,0 +1,45 @@ +using System.Collections.Generic; +using System.Threading.Tasks; +using Titanium.Web.Proxy.Helpers; +using Titanium.Web.Proxy.Models; +using Titanium.Web.Proxy.Shared; + +namespace Titanium.Web.Proxy.Http +{ + internal static class HeaderParser + { + internal static async Task ReadHeaders(CustomBinaryReader reader, + Dictionary> nonUniqueResponseHeaders, + Dictionary headers) + { + string tmpLine; + while (!string.IsNullOrEmpty(tmpLine = await reader.ReadLineAsync())) + { + var header = tmpLine.Split(ProxyConstants.ColonSplit, 2); + + var newHeader = new HttpHeader(header[0], header[1]); + + //if header exist in non-unique header collection add it there + if (nonUniqueResponseHeaders.ContainsKey(newHeader.Name)) + { + nonUniqueResponseHeaders[newHeader.Name].Add(newHeader); + } + //if header is alread in unique header collection then move both to non-unique collection + else if (headers.ContainsKey(newHeader.Name)) + { + var existing = headers[newHeader.Name]; + + var nonUniqueHeaders = new List { existing, newHeader }; + + nonUniqueResponseHeaders.Add(newHeader.Name, nonUniqueHeaders); + headers.Remove(newHeader.Name); + } + //add to unique header collection + else + { + headers.Add(newHeader.Name, newHeader); + } + } + } + } +} diff --git a/Titanium.Web.Proxy/Http/HttpWebClient.cs b/Titanium.Web.Proxy/Http/HttpWebClient.cs index b91618223..46aa379c7 100644 --- a/Titanium.Web.Proxy/Http/HttpWebClient.cs +++ b/Titanium.Web.Proxy/Http/HttpWebClient.cs @@ -170,10 +170,10 @@ internal async Task ReceiveResponse() var httpVersion = httpResult[0].Trim().ToLower(); - var version = new Version(1, 1); - if (0 == string.CompareOrdinal(httpVersion, "http/1.0")) + var version = HttpHeader.Version11; + if (string.Equals(httpVersion, "HTTP/1.0", StringComparison.OrdinalIgnoreCase)) { - version = new Version(1, 0); + version = HttpHeader.Version10; } Response.HttpVersion = version; @@ -192,8 +192,9 @@ internal async Task ReceiveResponse() await ReceiveResponse(); return; } - else if (Response.ResponseStatusCode.Equals("417") - && Response.ResponseStatusDescription.Equals("expectation failed", StringComparison.CurrentCultureIgnoreCase)) + + if (Response.ResponseStatusCode.Equals("417") + && Response.ResponseStatusDescription.Equals("expectation failed", StringComparison.CurrentCultureIgnoreCase)) { //read next line after expectation failed response Response.ExpectationFailed = true; @@ -204,36 +205,8 @@ internal async Task ReceiveResponse() return; } - //Read the Response headers //Read the response headers in to unique and non-unique header collections - string tmpLine; - while (!string.IsNullOrEmpty(tmpLine = await ServerConnection.StreamReader.ReadLineAsync())) - { - var header = tmpLine.Split(ProxyConstants.ColonSplit, 2); - - var newHeader = new HttpHeader(header[0], header[1]); - - //if header exist in non-unique header collection add it there - if (Response.NonUniqueResponseHeaders.ContainsKey(newHeader.Name)) - { - Response.NonUniqueResponseHeaders[newHeader.Name].Add(newHeader); - } - //if header is alread in unique header collection then move both to non-unique collection - else if (Response.ResponseHeaders.ContainsKey(newHeader.Name)) - { - var existing = Response.ResponseHeaders[newHeader.Name]; - - var nonUniqueHeaders = new List {existing, newHeader}; - - Response.NonUniqueResponseHeaders.Add(newHeader.Name, nonUniqueHeaders); - Response.ResponseHeaders.Remove(newHeader.Name); - } - //add to unique header collection - else - { - Response.ResponseHeaders.Add(newHeader.Name, newHeader); - } - } + await HeaderParser.ReadHeaders(ServerConnection.StreamReader, Response.NonUniqueResponseHeaders, Response.ResponseHeaders); } } } diff --git a/Titanium.Web.Proxy/Http/Request.cs b/Titanium.Web.Proxy/Http/Request.cs index 8adf03842..949cd5570 100644 --- a/Titanium.Web.Proxy/Http/Request.cs +++ b/Titanium.Web.Proxy/Http/Request.cs @@ -241,9 +241,9 @@ public bool ExpectContinue /// request body as string /// internal string RequestBodyString { get; set; } - - + internal bool RequestBodyRead { get; set; } + internal bool RequestLocked { get; set; } /// diff --git a/Titanium.Web.Proxy/Models/HttpHeader.cs b/Titanium.Web.Proxy/Models/HttpHeader.cs index 6f884e720..c5a678e69 100644 --- a/Titanium.Web.Proxy/Models/HttpHeader.cs +++ b/Titanium.Web.Proxy/Models/HttpHeader.cs @@ -9,6 +9,10 @@ namespace Titanium.Web.Proxy.Models /// public class HttpHeader { + internal static Version Version10 = new Version(1, 0); + + internal static Version Version11 = new Version(1, 1); + /// /// Constructor. /// diff --git a/Titanium.Web.Proxy/Network/CertificateManager.cs b/Titanium.Web.Proxy/Network/CertificateManager.cs index 1d11444f2..d3e7547de 100644 --- a/Titanium.Web.Proxy/Network/CertificateManager.cs +++ b/Titanium.Web.Proxy/Network/CertificateManager.cs @@ -29,9 +29,9 @@ public enum CertificateEngine /// /// A class to manage SSL certificates used by this proxy server /// - internal class CertificateManager : IDisposable + public class CertificateManager : IDisposable { - public CertificateEngine Engine + internal CertificateEngine Engine { get { return engine; } set @@ -164,10 +164,13 @@ internal X509Certificate2 LoadRootCertificate() /// /// Attempts to create a RootCertificate /// - /// true if succeeded, else false - internal bool CreateTrustedRootCertificate() + /// if set to true try to load/save the certificate from rootCert.pfx. + /// + /// true if succeeded, else false + /// + public bool CreateTrustedRootCertificate(bool persistToFile = true) { - if (RootCertificate == null) + if (persistToFile && RootCertificate == null) { RootCertificate = LoadRootCertificate(); } @@ -186,7 +189,7 @@ internal bool CreateTrustedRootCertificate() exceptionFunc(e); } - if (RootCertificate != null) + if (persistToFile && RootCertificate != null) { try { @@ -205,7 +208,7 @@ internal bool CreateTrustedRootCertificate() /// /// Trusts the root certificate. /// - internal void TrustRootCertificate() + public void TrustRootCertificate() { //current user TrustRootCertificate(StoreLocation.CurrentUser); @@ -214,6 +217,18 @@ internal void TrustRootCertificate() TrustRootCertificate(StoreLocation.LocalMachine); } + /// + /// Removes the trusted certificates. + /// + public void RemoveTrustedRootCertificates() + { + //current user + RemoveTrustedRootCertificates(StoreLocation.CurrentUser); + + //current system + RemoveTrustedRootCertificates(StoreLocation.LocalMachine); + } + /// /// Create an SSL certificate /// @@ -336,6 +351,46 @@ internal void TrustRootCertificate(StoreLocation storeLocation) } } + /// + /// Remove the Root Certificate trust + /// + /// + /// + internal void RemoveTrustedRootCertificates(StoreLocation storeLocation) + { + if (RootCertificate == null) + { + exceptionFunc( + new Exception("Could not set root certificate" + + " as system proxy since it is null or empty.")); + + return; + } + + X509Store x509RootStore = new X509Store(StoreName.Root, storeLocation); + var x509PersonalStore = new X509Store(StoreName.My, storeLocation); + + try + { + x509RootStore.Open(OpenFlags.ReadWrite); + x509PersonalStore.Open(OpenFlags.ReadWrite); + + x509RootStore.Remove(RootCertificate); + x509PersonalStore.Remove(RootCertificate); + } + catch (Exception e) + { + exceptionFunc( + new Exception("Failed to make system trust root certificate " + + $" for {storeLocation} store location. You may need admin rights.", e)); + } + finally + { + x509RootStore.Close(); + x509PersonalStore.Close(); + } + } + public void Dispose() { } diff --git a/Titanium.Web.Proxy/Network/Tcp/TcpConnection.cs b/Titanium.Web.Proxy/Network/Tcp/TcpConnection.cs index 5b3bdfce5..70ced1fc0 100644 --- a/Titanium.Web.Proxy/Network/Tcp/TcpConnection.cs +++ b/Titanium.Web.Proxy/Network/Tcp/TcpConnection.cs @@ -55,10 +55,10 @@ public void Dispose() { Stream?.Close(); Stream?.Dispose(); + StreamReader?.Dispose(); - TcpClient.LingerState = new LingerOption(true, 0); - TcpClient.Close(); + TcpClient?.Close(); } } } diff --git a/Titanium.Web.Proxy/Network/Tcp/TcpConnectionFactory.cs b/Titanium.Web.Proxy/Network/Tcp/TcpConnectionFactory.cs index c5cf213ee..2bef02038 100644 --- a/Titanium.Web.Proxy/Network/Tcp/TcpConnectionFactory.cs +++ b/Titanium.Web.Proxy/Network/Tcp/TcpConnectionFactory.cs @@ -6,43 +6,35 @@ using System.Net.Security; using Titanium.Web.Proxy.Helpers; using Titanium.Web.Proxy.Models; -using System.Security.Authentication; using System.Linq; using Titanium.Web.Proxy.Extensions; using Titanium.Web.Proxy.Shared; namespace Titanium.Web.Proxy.Network.Tcp { - using System.Net; - /// /// A class that manages Tcp Connection to server used by this proxy server /// internal class TcpConnectionFactory { + /// /// Creates a TCP connection to server /// - /// - /// + /// /// + /// /// /// - /// - /// - /// - /// /// /// /// - /// /// - internal async Task CreateClient(int bufferSize, int connectionTimeOutSeconds, + internal async Task CreateClient(ProxyServer server, string remoteHostName, int remotePort, Version httpVersion, - bool isHttps, SslProtocols supportedSslProtocols, - RemoteCertificateValidationCallback remoteCertificateValidationCallback, LocalCertificateSelectionCallback localCertificateSelectionCallback, + bool isHttps, ExternalProxy externalHttpProxy, ExternalProxy externalHttpsProxy, - Stream clientStream, IPEndPoint upStreamEndPoint) + Stream clientStream) { TcpClient client; CustomBufferedStream stream; @@ -59,11 +51,11 @@ internal async Task CreateClient(int bufferSize, int connectionTi //If this proxy uses another external proxy then create a tunnel request for HTTPS connections if (useHttpsProxy) { - client = new TcpClient(upStreamEndPoint); + client = new TcpClient(server.UpStreamEndPoint); await client.ConnectAsync(externalHttpsProxy.HostName, externalHttpsProxy.Port); - stream = new CustomBufferedStream(client.GetStream(), bufferSize); + stream = new CustomBufferedStream(client.GetStream(), server.BufferSize); - using (var writer = new StreamWriter(stream, Encoding.ASCII, bufferSize, true) {NewLine = ProxyConstants.NewLine}) + using (var writer = new StreamWriter(stream, Encoding.ASCII, server.BufferSize, true) {NewLine = ProxyConstants.NewLine}) { await writer.WriteLineAsync($"CONNECT {remoteHostName}:{remotePort} HTTP/{httpVersion}"); await writer.WriteLineAsync($"Host: {remoteHostName}:{remotePort}"); @@ -79,7 +71,7 @@ internal async Task CreateClient(int bufferSize, int connectionTi writer.Close(); } - using (var reader = new CustomBinaryReader(stream, bufferSize)) + using (var reader = new CustomBinaryReader(stream, server.BufferSize)) { var result = await reader.ReadLineAsync(); @@ -93,19 +85,19 @@ internal async Task CreateClient(int bufferSize, int connectionTi } else { - client = new TcpClient(upStreamEndPoint); + client = new TcpClient(server.UpStreamEndPoint); await client.ConnectAsync(remoteHostName, remotePort); - stream = new CustomBufferedStream(client.GetStream(), bufferSize); + stream = new CustomBufferedStream(client.GetStream(), server.BufferSize); } try { - sslStream = new SslStream(stream, true, remoteCertificateValidationCallback, - localCertificateSelectionCallback); + sslStream = new SslStream(stream, true, server.ValidateServerCertificate, + server.SelectClientCertificate); - await sslStream.AuthenticateAsClientAsync(remoteHostName, null, supportedSslProtocols, false); + await sslStream.AuthenticateAsClientAsync(remoteHostName, null, server.SupportedSslProtocols, false); - stream = new CustomBufferedStream(sslStream, bufferSize); + stream = new CustomBufferedStream(sslStream, server.BufferSize); } catch { @@ -118,24 +110,24 @@ internal async Task CreateClient(int bufferSize, int connectionTi { if (useHttpProxy) { - client = new TcpClient(upStreamEndPoint); + client = new TcpClient(server.UpStreamEndPoint); await client.ConnectAsync(externalHttpProxy.HostName, externalHttpProxy.Port); - stream = new CustomBufferedStream(client.GetStream(), bufferSize); + stream = new CustomBufferedStream(client.GetStream(), server.BufferSize); } else { - client = new TcpClient(upStreamEndPoint); + client = new TcpClient(server.UpStreamEndPoint); await client.ConnectAsync(remoteHostName, remotePort); - stream = new CustomBufferedStream(client.GetStream(), bufferSize); + stream = new CustomBufferedStream(client.GetStream(), server.BufferSize); } } - client.ReceiveTimeout = connectionTimeOutSeconds * 1000; - client.SendTimeout = connectionTimeOutSeconds * 1000; + client.ReceiveTimeout = server.ConnectionTimeOutSeconds * 1000; + client.SendTimeout = server.ConnectionTimeOutSeconds * 1000; - stream.ReadTimeout = connectionTimeOutSeconds * 1000; - stream.WriteTimeout = connectionTimeOutSeconds * 1000; + client.LingerState = new LingerOption(true, 0); + server.ServerConnectionCount++; return new TcpConnection { @@ -145,7 +137,7 @@ internal async Task CreateClient(int bufferSize, int connectionTi Port = remotePort, IsHttps = isHttps, TcpClient = client, - StreamReader = new CustomBinaryReader(stream, bufferSize), + StreamReader = new CustomBinaryReader(stream, server.BufferSize), Stream = stream, Version = httpVersion }; diff --git a/Titanium.Web.Proxy/Network/Tcp/TcpRow.cs b/Titanium.Web.Proxy/Network/Tcp/TcpRow.cs index 9a8a5f6c3..5bc306c05 100644 --- a/Titanium.Web.Proxy/Network/Tcp/TcpRow.cs +++ b/Titanium.Web.Proxy/Network/Tcp/TcpRow.cs @@ -1,4 +1,5 @@ using System.Net; +using Titanium.Web.Proxy.Extensions; using Titanium.Web.Proxy.Helpers; namespace Titanium.Web.Proxy.Network.Tcp @@ -17,24 +18,42 @@ internal TcpRow(NativeMethods.TcpRow tcpRow) { ProcessId = tcpRow.owningPid; - int localPort = (tcpRow.localPort1 << 8) + (tcpRow.localPort2) + (tcpRow.localPort3 << 24) + (tcpRow.localPort4 << 16); - long localAddress = tcpRow.localAddr; - LocalEndPoint = new IPEndPoint(localAddress, localPort); + LocalPort = tcpRow.GetLocalPort(); + LocalAddress = tcpRow.localAddr; - int remotePort = (tcpRow.remotePort1 << 8) + (tcpRow.remotePort2) + (tcpRow.remotePort3 << 24) + (tcpRow.remotePort4 << 16); - long remoteAddress = tcpRow.remoteAddr; - RemoteEndPoint = new IPEndPoint(remoteAddress, remotePort); + RemotePort = tcpRow.GetRemotePort(); + RemoteAddress = tcpRow.remoteAddr; } + /// + /// Gets the local end point address. + /// + internal long LocalAddress { get; } + + /// + /// Gets the local end point port. + /// + internal int LocalPort { get; } + /// /// Gets the local end point. /// - internal IPEndPoint LocalEndPoint { get; } + internal IPEndPoint LocalEndPoint => new IPEndPoint(LocalAddress, LocalPort); + + /// + /// Gets the remote end point address. + /// + internal long RemoteAddress { get; } + + /// + /// Gets the remote end point port. + /// + internal int RemotePort { get; } /// /// Gets the remote end point. /// - internal IPEndPoint RemoteEndPoint { get; } + internal IPEndPoint RemoteEndPoint => new IPEndPoint(RemoteAddress, RemotePort); /// /// Gets the process identifier. diff --git a/Titanium.Web.Proxy/ProxyAuthorizationHandler.cs b/Titanium.Web.Proxy/ProxyAuthorizationHandler.cs index 9c9a78e4e..c8279925c 100644 --- a/Titanium.Web.Proxy/ProxyAuthorizationHandler.cs +++ b/Titanium.Web.Proxy/ProxyAuthorizationHandler.cs @@ -19,78 +19,36 @@ private async Task CheckAuthorization(StreamWriter clientStreamWriter, IEn return true; } - var httpHeaders = headers as HttpHeader[] ?? headers.ToArray(); + var httpHeaders = headers as ICollection ?? headers.ToArray(); try { if (httpHeaders.All(t => t.Name != "Proxy-Authorization")) { - await WriteResponseStatus(new Version(1, 1), "407", - "Proxy Authentication Required", clientStreamWriter); - var response = new Response - { - ResponseHeaders = new Dictionary - { - { - "Proxy-Authenticate", - new HttpHeader("Proxy-Authenticate", "Basic realm=\"TitaniumProxy\"") - }, - {"Proxy-Connection", new HttpHeader("Proxy-Connection", "close")} - } - }; - await WriteResponseHeaders(clientStreamWriter, response); - - await clientStreamWriter.WriteLineAsync(); + await SendAuthentication407Response(clientStreamWriter, "Proxy Authentication Required"); return false; } + var header = httpHeaders.FirstOrDefault(t => t.Name == "Proxy-Authorization"); - if (null == header) throw new NullReferenceException(); + if (header == null) throw new NullReferenceException(); var headerValue = header.Value.Trim(); if (!headerValue.StartsWith("basic", StringComparison.CurrentCultureIgnoreCase)) { //Return not authorized - await WriteResponseStatus(new Version(1, 1), "407", - "Proxy Authentication Invalid", clientStreamWriter); - var response = new Response - { - ResponseHeaders = new Dictionary - { - { - "Proxy-Authenticate", - new HttpHeader("Proxy-Authenticate", "Basic realm=\"TitaniumProxy\"") - }, - {"Proxy-Connection", new HttpHeader("Proxy-Connection", "close")} - } - }; - await WriteResponseHeaders(clientStreamWriter, response); - - await clientStreamWriter.WriteLineAsync(); + await SendAuthentication407Response(clientStreamWriter, "Proxy Authentication Invalid"); return false; } + headerValue = headerValue.Substring(5).Trim(); var decoded = Encoding.UTF8.GetString(Convert.FromBase64String(headerValue)); if (decoded.Contains(":") == false) { //Return not authorized - await WriteResponseStatus(new Version(1, 1), "407", - "Proxy Authentication Invalid", clientStreamWriter); - var response = new Response - { - ResponseHeaders = new Dictionary - { - { - "Proxy-Authenticate", - new HttpHeader("Proxy-Authenticate", "Basic realm=\"TitaniumProxy\"") - }, - {"Proxy-Connection", new HttpHeader("Proxy-Connection", "close")} - } - }; - await WriteResponseHeaders(clientStreamWriter, response); - - await clientStreamWriter.WriteLineAsync(); + await SendAuthentication407Response(clientStreamWriter, "Proxy Authentication Invalid"); return false; } + var username = decoded.Substring(0, decoded.IndexOf(':')); var password = decoded.Substring(decoded.IndexOf(':') + 1); return await AuthenticateUserFunc(username, password); @@ -98,22 +56,27 @@ await WriteResponseStatus(new Version(1, 1), "407", catch (Exception e) { ExceptionFunc(new ProxyAuthorizationException("Error whilst authorizing request", e, httpHeaders)); - //Return not authorized - await WriteResponseStatus(new Version(1, 1), "407", - "Proxy Authentication Invalid", clientStreamWriter); - var response = new Response - { - ResponseHeaders = new Dictionary - { - {"Proxy-Authenticate", new HttpHeader("Proxy-Authenticate", "Basic realm=\"TitaniumProxy\"")}, - {"Proxy-Connection", new HttpHeader("Proxy-Connection", "close")} - } - }; - await WriteResponseHeaders(clientStreamWriter, response); - await clientStreamWriter.WriteLineAsync(); + //Return not authorized + await SendAuthentication407Response(clientStreamWriter, "Proxy Authentication Invalid"); return false; } } + + private async Task SendAuthentication407Response(StreamWriter clientStreamWriter, string description) + { + await WriteResponseStatus(HttpHeader.Version11, "407", description, clientStreamWriter); + var response = new Response + { + ResponseHeaders = new Dictionary + { + {"Proxy-Authenticate", new HttpHeader("Proxy-Authenticate", "Basic realm=\"TitaniumProxy\"")}, + {"Proxy-Connection", new HttpHeader("Proxy-Connection", "close")} + } + }; + await WriteResponseHeaders(clientStreamWriter, response); + + await clientStreamWriter.WriteLineAsync(); + } } } diff --git a/Titanium.Web.Proxy/ProxyServer.cs b/Titanium.Web.Proxy/ProxyServer.cs index 75ee12fef..de3a0d928 100644 --- a/Titanium.Web.Proxy/ProxyServer.cs +++ b/Titanium.Web.Proxy/ProxyServer.cs @@ -24,11 +24,6 @@ public partial class ProxyServer : IDisposable /// private bool proxyRunning { get; set; } - /// - /// Manages certificates used by this proxy - /// - private CertificateManager certificateManager { get; set; } - /// /// An default exception log func /// @@ -64,13 +59,18 @@ private FireFoxProxySettingsManager firefoxProxySettingsManager /// public int BufferSize { get; set; } = 8192; + /// + /// Manages certificates used by this proxy + /// + public CertificateManager CertificateManager { get; } + /// /// The root certificate /// public X509Certificate2 RootCertificate { - get { return certificateManager.RootCertificate; } - set { certificateManager.RootCertificate = value; } + get { return CertificateManager.RootCertificate; } + set { CertificateManager.RootCertificate = value; } } /// @@ -79,8 +79,8 @@ public X509Certificate2 RootCertificate /// public string RootCertificateIssuerName { - get { return certificateManager.Issuer; } - set { certificateManager.RootCertificateName = value; } + get { return CertificateManager.Issuer; } + set { CertificateManager.Issuer = value; } } /// @@ -92,8 +92,8 @@ public string RootCertificateIssuerName /// public string RootCertificateName { - get { return certificateManager.RootCertificateName; } - set { certificateManager.Issuer = value; } + get { return CertificateManager.RootCertificateName; } + set { CertificateManager.RootCertificateName = value; } } /// @@ -121,8 +121,8 @@ public bool TrustRootCertificate /// public CertificateEngine CertificateEngine { - get { return certificateManager.Engine; } - set { certificateManager.Engine = value; } + get { return CertificateManager.Engine; } + set { CertificateManager.Engine = value; } } /// @@ -226,6 +226,17 @@ public Action ExceptionFunc public SslProtocols SupportedSslProtocols { get; set; } = SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12 | SslProtocols.Ssl3; + /// + /// Total number of active client connections + /// + public int ClientConnectionCount { get; private set; } + + + /// + /// Total number of active server connections + /// + public int ServerConnectionCount { get; internal set; } + /// /// Constructor /// @@ -241,7 +252,7 @@ public ProxyServer() : this(null, null) public ProxyServer(string rootCertificateName, string rootCertificateIssuerName) { //default values - ConnectionTimeOutSeconds = 120; + ConnectionTimeOutSeconds = 30; CertificateCacheTimeOutMinutes = 60; ProxyEndPoints = new List(); @@ -251,7 +262,7 @@ public ProxyServer(string rootCertificateName, string rootCertificateIssuerName) new FireFoxProxySettingsManager(); #endif - certificateManager = new CertificateManager(ExceptionFunc); + CertificateManager = new CertificateManager(ExceptionFunc); if (rootCertificateName != null) { RootCertificateName = rootCertificateName; @@ -356,7 +367,7 @@ public void SetAsSystemHttpsProxy(ExplicitProxyEndPoint endPoint) EnsureRootCertificate(); //If certificate was trusted by the machine - if (certificateManager.CertValidated) + if (CertificateManager.CertValidated) { systemProxySettingsManager.SetHttpsProxy( Equals(endPoint.IpAddress, IPAddress.Any) | @@ -435,7 +446,7 @@ public void Start() Listen(endPoint); } - certificateManager.ClearIdleCertificates(CertificateCacheTimeOutMinutes); + CertificateManager.ClearIdleCertificates(CertificateCacheTimeOutMinutes); proxyRunning = true; } @@ -468,7 +479,7 @@ public void Stop() ProxyEndPoints.Clear(); - certificateManager?.StopClearIdleCertificates(); + CertificateManager?.StopClearIdleCertificates(); proxyRunning = false; } @@ -483,7 +494,7 @@ public void Dispose() Stop(); } - certificateManager?.Dispose(); + CertificateManager?.Dispose(); } /// @@ -495,7 +506,7 @@ private void Listen(ProxyEndPoint endPoint) endPoint.Listener = new TcpListener(endPoint.IpAddress, endPoint.Port); endPoint.Listener.Start(); - endPoint.Port = ((IPEndPoint) endPoint.Listener.LocalEndpoint).Port; + endPoint.Port = ((IPEndPoint)endPoint.Listener.LocalEndpoint).Port; // accept clients asynchronously endPoint.Listener.BeginAcceptTcpClient(OnAcceptConnection, endPoint); } @@ -543,13 +554,13 @@ private Task GetSystemUpStreamProxy(SessionEventArgs sessionEvent private void EnsureRootCertificate() { - if (!certificateManager.CertValidated) + if (!CertificateManager.CertValidated) { - certificateManager.CreateTrustedRootCertificate(); + CertificateManager.CreateTrustedRootCertificate(); if (TrustRootCertificate) { - certificateManager.TrustRootCertificate(); + CertificateManager.TrustRootCertificate(); } } } @@ -560,7 +571,7 @@ private void EnsureRootCertificate() /// private void OnAcceptConnection(IAsyncResult asyn) { - var endPoint = (ProxyEndPoint) asyn.AsyncState; + var endPoint = (ProxyEndPoint)asyn.AsyncState; TcpClient tcpClient = null; @@ -581,11 +592,18 @@ private void OnAcceptConnection(IAsyncResult asyn) //Other errors are discarded to keep proxy running } - if (tcpClient != null) { Task.Run(async () => { + ClientConnectionCount++; + + //This line is important! + //contributors please don't remove it without discussion + //It helps to avoid eventual deterioration of performance due to TCP port exhaustion + //due to default TCP CLOSE_WAIT timeout for 4 minutes + tcpClient.LingerState = new LingerOption(true, 0); + try { if (endPoint.GetType() == typeof(TransparentProxyEndPoint)) @@ -599,15 +617,8 @@ private void OnAcceptConnection(IAsyncResult asyn) } finally { - if (tcpClient != null) - { - //This line is important! - //contributors please don't remove it without discussion - //It helps to avoid eventual deterioration of performance due to TCP port exhaustion - //due to default TCP CLOSE_WAIT timeout for 4 minutes - tcpClient.LingerState = new LingerOption(true, 0); - tcpClient.Close(); - } + ClientConnectionCount--; + tcpClient?.Close(); } }); } diff --git a/Titanium.Web.Proxy/RequestHandler.cs b/Titanium.Web.Proxy/RequestHandler.cs index b4fc6fa0c..60f3975e8 100644 --- a/Titanium.Web.Proxy/RequestHandler.cs +++ b/Titanium.Web.Proxy/RequestHandler.cs @@ -33,9 +33,10 @@ private async Task HandleClient(ExplicitProxyEndPoint endPoint, TcpClient tcpCli clientStream.WriteTimeout = ConnectionTimeOutSeconds * 1000; var clientStreamReader = new CustomBinaryReader(clientStream, BufferSize); - var clientStreamWriter = new StreamWriter(clientStream) {NewLine = ProxyConstants.NewLine}; + var clientStreamWriter = new StreamWriter(clientStream) { NewLine = ProxyConstants.NewLine }; Uri httpRemoteUri; + try { //read the first line HTTP command @@ -56,14 +57,14 @@ private async Task HandleClient(ExplicitProxyEndPoint endPoint, TcpClient tcpCli httpRemoteUri = httpVerb == "CONNECT" ? new Uri("http://" + httpCmdSplit[1]) : new Uri(httpCmdSplit[1]); //parse the HTTP version - var version = new Version(1, 1); + var version = HttpHeader.Version11; if (httpCmdSplit.Length == 3) { var httpVersion = httpCmdSplit[2].Trim(); - if (0 == string.CompareOrdinal(httpVersion, "http/1.0")) + if (string.Equals(httpVersion, "HTTP/1.0", StringComparison.OrdinalIgnoreCase)) { - version = new Version(1, 0); + version = HttpHeader.Version10; } } @@ -86,8 +87,8 @@ private async Task HandleClient(ExplicitProxyEndPoint endPoint, TcpClient tcpCli if (httpVerb == "CONNECT" && !excluded && httpRemoteUri.Port != 80) { httpRemoteUri = new Uri("https://" + httpCmdSplit[1]); - string tmpLine; connectRequestHeaders = new List(); + string tmpLine; while (!string.IsNullOrEmpty(tmpLine = await clientStreamReader.ReadLineAsync())) { var header = tmpLine.Split(ProxyConstants.ColonSplit, 2); @@ -110,7 +111,7 @@ private async Task HandleClient(ExplicitProxyEndPoint endPoint, TcpClient tcpCli { sslStream = new SslStream(clientStream, true); - var certificate = endPoint.GenericCertificate ?? certificateManager.CreateCertificate(httpRemoteUri.Host, false); + var certificate = endPoint.GenericCertificate ?? CertificateManager.CreateCertificate(httpRemoteUri.Host, false); //Successfully managed to authenticate the client using the fake certificate await sslStream.AuthenticateAsServerAsync(certificate, false, @@ -119,7 +120,7 @@ await sslStream.AuthenticateAsServerAsync(certificate, false, clientStream = new CustomBufferedStream(sslStream, BufferSize); clientStreamReader = new CustomBinaryReader(clientStream, BufferSize); - clientStreamWriter = new StreamWriter(clientStream) {NewLine = ProxyConstants.NewLine}; + clientStreamWriter = new StreamWriter(clientStream) { NewLine = ProxyConstants.NewLine }; } catch { @@ -140,12 +141,11 @@ await sslStream.AuthenticateAsServerAsync(certificate, false, //write back successfull CONNECT response await WriteConnectResponse(clientStreamWriter, version); - await TcpHelper.SendRaw(BufferSize, ConnectionTimeOutSeconds, httpRemoteUri.Host, httpRemoteUri.Port, + await TcpHelper.SendRaw(this, + httpRemoteUri.Host, httpRemoteUri.Port, null, version, null, - false, SupportedSslProtocols, - ValidateServerCertificate, - SelectClientCertificate, - clientStream, tcpConnectionFactory, UpStreamEndPoint); + false, + clientStream, tcpConnectionFactory); Dispose(clientStream, clientStreamReader, clientStreamWriter, null); return; @@ -156,7 +156,9 @@ await HandleHttpSessionRequest(tcpClient, httpCmd, clientStream, clientStreamRea } catch (Exception) { - Dispose(clientStream, clientStreamReader, clientStreamWriter, null); + Dispose(clientStream, + clientStreamReader, + clientStreamWriter, null); } } @@ -177,7 +179,7 @@ private async Task HandleClient(TransparentProxyEndPoint endPoint, TcpClient tcp var sslStream = new SslStream(clientStream, true); //implement in future once SNI supported by SSL stream, for now use the same certificate - var certificate = certificateManager.CreateCertificate(endPoint.GenericCertificateName, false); + var certificate = CertificateManager.CreateCertificate(endPoint.GenericCertificateName, false); try { @@ -187,21 +189,24 @@ await sslStream.AuthenticateAsServerAsync(certificate, false, clientStream = new CustomBufferedStream(sslStream, BufferSize); clientStreamReader = new CustomBinaryReader(clientStream, BufferSize); - clientStreamWriter = new StreamWriter(clientStream) {NewLine = ProxyConstants.NewLine}; + clientStreamWriter = new StreamWriter(clientStream) { NewLine = ProxyConstants.NewLine }; //HTTPS server created - we can now decrypt the client's traffic } catch (Exception) { sslStream.Dispose(); - Dispose(sslStream, clientStreamReader, clientStreamWriter, null); + Dispose(sslStream, + clientStreamReader, + clientStreamWriter, null); + return; } } else { clientStreamReader = new CustomBinaryReader(clientStream, BufferSize); - clientStreamWriter = new StreamWriter(clientStream) {NewLine = ProxyConstants.NewLine}; + clientStreamWriter = new StreamWriter(clientStream) { NewLine = ProxyConstants.NewLine }; } //now read the request line @@ -212,45 +217,62 @@ await HandleHttpSessionRequest(tcpClient, httpCmd, clientStream, clientStreamRea endPoint.EnableSsl ? endPoint.GenericCertificateName : null, endPoint, null); } - private async Task HandleHttpSessionRequestInternal(TcpConnection connection, SessionEventArgs args, ExternalProxy customUpStreamHttpProxy, ExternalProxy customUpStreamHttpsProxy, bool closeConnection) + /// + /// Create a Server Connection + /// + /// + /// + private async Task GetServerConnection( + SessionEventArgs args) { - try + ExternalProxy customUpStreamHttpProxy = null; + ExternalProxy customUpStreamHttpsProxy = null; + + if (args.WebSession.Request.RequestUri.Scheme == "http") { - if (connection == null) + if (GetCustomUpStreamHttpProxyFunc != null) { - if (args.WebSession.Request.RequestUri.Scheme == "http") - { - if (GetCustomUpStreamHttpProxyFunc != null) - { - customUpStreamHttpProxy = await GetCustomUpStreamHttpProxyFunc(args); - } - } - else - { - if (GetCustomUpStreamHttpsProxyFunc != null) - { - customUpStreamHttpsProxy = await GetCustomUpStreamHttpsProxyFunc(args); - } - } + customUpStreamHttpProxy = await GetCustomUpStreamHttpProxyFunc(args); + } + } + else + { + if (GetCustomUpStreamHttpsProxyFunc != null) + { + customUpStreamHttpsProxy = await GetCustomUpStreamHttpsProxyFunc(args); + } + } - args.CustomUpStreamHttpProxyUsed = customUpStreamHttpProxy; - args.CustomUpStreamHttpsProxyUsed = customUpStreamHttpsProxy; + args.CustomUpStreamHttpProxyUsed = customUpStreamHttpProxy; + args.CustomUpStreamHttpsProxyUsed = customUpStreamHttpsProxy; + + return await tcpConnectionFactory.CreateClient(this, + args.WebSession.Request.RequestUri.Host, + args.WebSession.Request.RequestUri.Port, + args.WebSession.Request.HttpVersion, + args.IsHttps, + customUpStreamHttpProxy ?? UpStreamHttpProxy, + customUpStreamHttpsProxy ?? UpStreamHttpsProxy, + args.ProxyClient.ClientStream); + } - connection = await tcpConnectionFactory.CreateClient(BufferSize, ConnectionTimeOutSeconds, - args.WebSession.Request.RequestUri.Host, args.WebSession.Request.RequestUri.Port, args.WebSession.Request.HttpVersion, - args.IsHttps, SupportedSslProtocols, - ValidateServerCertificate, - SelectClientCertificate, - customUpStreamHttpProxy ?? UpStreamHttpProxy, customUpStreamHttpsProxy ?? UpStreamHttpsProxy, args.ProxyClient.ClientStream, UpStreamEndPoint); - } + private async Task HandleHttpSessionRequestInternal(TcpConnection connection, + SessionEventArgs args, bool closeConnection) + { + try + { args.WebSession.Request.RequestLocked = true; //If request was cancelled by user then dispose the client if (args.WebSession.Request.CancelRequest) { - Dispose(args.ProxyClient.ClientStream, args.ProxyClient.ClientStreamReader, args.ProxyClient.ClientStreamWriter, args); - return; + Dispose(args.ProxyClient.ClientStream, + args.ProxyClient.ClientStreamReader, + args.ProxyClient.ClientStreamWriter, + args.WebSession.ServerConnection); + + return false; } //if expect continue is enabled then send the headers first @@ -314,28 +336,50 @@ await WriteResponseStatus(args.WebSession.Response.HttpVersion, "417", //If not expectation failed response was returned by server then parse response if (!args.WebSession.Request.ExpectationFailed) { - await HandleHttpSessionResponse(args); + var result = await HandleHttpSessionResponse(args); + + //already disposed inside above method + if (result == false) + { + return false; + } } //if connection is closing exit if (args.WebSession.Response.ResponseKeepAlive == false) { - Dispose(args.ProxyClient.ClientStream, args.ProxyClient.ClientStreamReader, args.ProxyClient.ClientStreamWriter, args); - return; + Dispose(args.ProxyClient.ClientStream, + args.ProxyClient.ClientStreamReader, + args.ProxyClient.ClientStreamWriter, + args.WebSession.ServerConnection); + + return false; } } catch (Exception e) { ExceptionFunc(new ProxyHttpException("Error occured whilst handling session request (internal)", e, args)); - Dispose(args.ProxyClient.ClientStream, args.ProxyClient.ClientStreamReader, args.ProxyClient.ClientStreamWriter, args); - return; + + Dispose(args.ProxyClient.ClientStream, + args.ProxyClient.ClientStreamReader, + args.ProxyClient.ClientStreamWriter, + args.WebSession.ServerConnection); + + return false; } if (closeConnection) { //dispose - connection?.Dispose(); + Dispose(args.ProxyClient.ClientStream, + args.ProxyClient.ClientStreamReader, + args.ProxyClient.ClientStreamWriter, + args.WebSession.ServerConnection); + + return false; } + + return true; } /// @@ -354,7 +398,7 @@ await WriteResponseStatus(args.WebSession.Response.HttpVersion, "417", /// private async Task HandleHttpSessionRequest(TcpClient client, string httpCmd, Stream clientStream, CustomBinaryReader clientStreamReader, StreamWriter clientStreamWriter, string httpsHostName, - ProxyEndPoint endPoint, List connectHeaders, ExternalProxy customUpStreamHttpProxy = null, ExternalProxy customUpStreamHttpsProxy = null) + ProxyEndPoint endPoint, List connectHeaders) { TcpConnection connection = null; @@ -364,20 +408,23 @@ private async Task HandleHttpSessionRequest(TcpClient client, string httpCmd, St { if (string.IsNullOrEmpty(httpCmd)) { - Dispose(clientStream, clientStreamReader, clientStreamWriter, null); + Dispose(clientStream, + clientStreamReader, + clientStreamWriter, + connection); + break; } - var args = - new SessionEventArgs(BufferSize, HandleHttpSessionResponse) + var args = new SessionEventArgs(BufferSize, HandleHttpSessionResponse) { - ProxyClient = {TcpClient = client}, - WebSession = {ConnectHeaders = connectHeaders} + ProxyClient = { TcpClient = client }, + WebSession = { ConnectHeaders = connectHeaders } }; args.WebSession.ProcessId = new Lazy(() => { - var remoteEndPoint = (IPEndPoint) args.ProxyClient.TcpClient.Client.RemoteEndPoint; + var remoteEndPoint = (IPEndPoint)args.ProxyClient.TcpClient.Client.RemoteEndPoint; //If client is localhost get the process id if (NetworkHelper.IsLocalIpAddress(remoteEndPoint.Address)) @@ -388,6 +435,7 @@ private async Task HandleHttpSessionRequest(TcpClient client, string httpCmd, St //can't access process Id of remote request from remote machine return -1; }); + try { //break up the line into three components (method, remote URL & Http Version) @@ -396,51 +444,22 @@ private async Task HandleHttpSessionRequest(TcpClient client, string httpCmd, St var httpMethod = httpCmdSplit[0]; //find the request HTTP version - var httpVersion = new Version(1, 1); + var httpVersion = HttpHeader.Version11; if (httpCmdSplit.Length == 3) { - var httpVersionString = httpCmdSplit[2].ToLower().Trim(); + var httpVersionString = httpCmdSplit[2].Trim(); - if (0 == string.CompareOrdinal(httpVersionString, "http/1.0")) + if (string.Equals(httpVersionString, "HTTP/1.0", StringComparison.OrdinalIgnoreCase)) { - httpVersion = new Version(1, 0); + httpVersion = HttpHeader.Version10; } } - //Read the request headers in to unique and non-unique header collections - string tmpLine; - while (!string.IsNullOrEmpty(tmpLine = await clientStreamReader.ReadLineAsync())) - { - var header = tmpLine.Split(ProxyConstants.ColonSplit, 2); - - var newHeader = new HttpHeader(header[0], header[1]); - - //if header exist in non-unique header collection add it there - if (args.WebSession.Request.NonUniqueRequestHeaders.ContainsKey(newHeader.Name)) - { - args.WebSession.Request.NonUniqueRequestHeaders[newHeader.Name].Add(newHeader); - } - //if header is alread in unique header collection then move both to non-unique collection - else if (args.WebSession.Request.RequestHeaders.ContainsKey(newHeader.Name)) - { - var existing = args.WebSession.Request.RequestHeaders[newHeader.Name]; - - var nonUniqueHeaders = new List {existing, newHeader}; - - - args.WebSession.Request.NonUniqueRequestHeaders.Add(newHeader.Name, nonUniqueHeaders); - args.WebSession.Request.RequestHeaders.Remove(newHeader.Name); - } - //add to unique header collection - else - { - args.WebSession.Request.RequestHeaders.Add(newHeader.Name, newHeader); - } - } + await HeaderParser.ReadHeaders(clientStreamReader, args.WebSession.Request.NonUniqueRequestHeaders, args.WebSession.Request.RequestHeaders); var httpRemoteUri = new Uri(httpsHostName == null ? httpCmdSplit[1] - : (string.Concat("https://", args.WebSession.Request.Host ?? httpsHostName, httpCmdSplit[1]))); + : string.Concat("https://", args.WebSession.Request.Host ?? httpsHostName, httpCmdSplit[1])); args.WebSession.Request.RequestUri = httpRemoteUri; @@ -450,9 +469,15 @@ private async Task HandleHttpSessionRequest(TcpClient client, string httpCmd, St args.ProxyClient.ClientStreamReader = clientStreamReader; args.ProxyClient.ClientStreamWriter = clientStreamWriter; - if (httpsHostName == null && (await CheckAuthorization(clientStreamWriter, args.WebSession.Request.RequestHeaders.Values) == false)) + if (httpsHostName == null && + await CheckAuthorization(clientStreamWriter, + args.WebSession.Request.RequestHeaders.Values) == false) { - Dispose(clientStream, clientStreamReader, clientStreamWriter, args); + Dispose(clientStream, + clientStreamReader, + clientStreamWriter, + connection); + break; } @@ -467,7 +492,7 @@ private async Task HandleHttpSessionRequest(TcpClient client, string httpCmd, St for (var i = 0; i < invocationList.Length; i++) { - handlerTasks[i] = ((Func) invocationList[i])(null, args); + handlerTasks[i] = ((Func)invocationList[i])(this, args); } await Task.WhenAll(handlerTasks); @@ -476,28 +501,51 @@ private async Task HandleHttpSessionRequest(TcpClient client, string httpCmd, St //if upgrading to websocket then relay the requet without reading the contents if (args.WebSession.Request.UpgradeToWebSocket) { - await TcpHelper.SendRaw(BufferSize, ConnectionTimeOutSeconds, httpRemoteUri.Host, httpRemoteUri.Port, + await TcpHelper.SendRaw(this, + httpRemoteUri.Host, httpRemoteUri.Port, httpCmd, httpVersion, args.WebSession.Request.RequestHeaders, args.IsHttps, - SupportedSslProtocols, ValidateServerCertificate, - SelectClientCertificate, - clientStream, tcpConnectionFactory, UpStreamEndPoint); + clientStream, tcpConnectionFactory); + + Dispose(clientStream, + clientStreamReader, + clientStreamWriter, + connection); - Dispose(clientStream, clientStreamReader, clientStreamWriter, args); break; } + if (connection == null) + { + connection = await GetServerConnection(args); + } + //construct the web request that we are going to issue on behalf of the client. - await HandleHttpSessionRequestInternal(null, args, customUpStreamHttpProxy, customUpStreamHttpsProxy, false); + var result = await HandleHttpSessionRequestInternal(connection, args, false); + if (result == false) + { + //already disposed inside above method + break; + } if (args.WebSession.Request.CancelRequest) { + Dispose(clientStream, + clientStreamReader, + clientStreamWriter, + connection); + break; } //if connection is closing exit if (args.WebSession.Response.ResponseKeepAlive == false) { + Dispose(clientStream, + clientStreamReader, + clientStreamWriter, + connection); + break; } @@ -507,13 +555,15 @@ await TcpHelper.SendRaw(BufferSize, ConnectionTimeOutSeconds, httpRemoteUri.Host catch (Exception e) { ExceptionFunc(new ProxyHttpException("Error occured whilst handling session request", e, args)); - Dispose(clientStream, clientStreamReader, clientStreamWriter, args); + + Dispose(clientStream, + clientStreamReader, + clientStreamWriter, + connection); break; } } - //dispose - connection?.Dispose(); } /// diff --git a/Titanium.Web.Proxy/ResponseHandler.cs b/Titanium.Web.Proxy/ResponseHandler.cs index e3fce3059..0fadbbee0 100644 --- a/Titanium.Web.Proxy/ResponseHandler.cs +++ b/Titanium.Web.Proxy/ResponseHandler.cs @@ -9,6 +9,7 @@ using Titanium.Web.Proxy.Extensions; using Titanium.Web.Proxy.Http; using Titanium.Web.Proxy.Helpers; +using Titanium.Web.Proxy.Network.Tcp; namespace Titanium.Web.Proxy { @@ -21,14 +22,14 @@ partial class ProxyServer /// Called asynchronously when a request was successfully and we received the response /// /// - /// - private async Task HandleHttpSessionResponse(SessionEventArgs args) + /// true if no errors + private async Task HandleHttpSessionResponse(SessionEventArgs args) { - //read response & headers from server - await args.WebSession.ReceiveResponse(); - try { + //read response & headers from server + await args.WebSession.ReceiveResponse(); + if (!args.WebSession.Response.ResponseBodyRead) { args.WebSession.Response.ResponseStream = args.WebSession.ServerConnection.Stream; @@ -44,7 +45,7 @@ private async Task HandleHttpSessionResponse(SessionEventArgs args) for (int i = 0; i < invocationList.Length; i++) { - handlerTasks[i] = ((Func) invocationList[i])(this, args); + handlerTasks[i] = ((Func)invocationList[i])(this, args); } await Task.WhenAll(handlerTasks); @@ -52,8 +53,15 @@ private async Task HandleHttpSessionResponse(SessionEventArgs args) if (args.ReRequest) { - await HandleHttpSessionRequestInternal(null, args, null, null, true); - return; + if(args.WebSession.ServerConnection != null) + { + args.WebSession.ServerConnection.Dispose(); + ServerConnectionCount--; + } + + var connection = await GetServerConnection(args); + var result = await HandleHttpSessionRequestInternal(null, args, true); + return result; } args.WebSession.Response.ResponseLocked = true; @@ -125,14 +133,16 @@ await args.WebSession.ServerConnection.StreamReader } catch (Exception e) { - ExceptionFunc(new ProxyHttpException("Error occured wilst handling session response", e, args)); + ExceptionFunc(new ProxyHttpException("Error occured whilst handling session response", e, args)); Dispose(args.ProxyClient.ClientStream, args.ProxyClient.ClientStreamReader, - args.ProxyClient.ClientStreamWriter, args); - } - finally - { - args.Dispose(); + args.ProxyClient.ClientStreamWriter, args.WebSession.ServerConnection); + + return false; } + + args.Dispose(); + + return true; } /// @@ -219,24 +229,26 @@ private void FixProxyHeaders(Dictionary headers) } /// - /// Handle dispose of a client/server session + /// Handle dispose of a client/server session /// /// /// /// - /// - private void Dispose(Stream clientStream, CustomBinaryReader clientStreamReader, - StreamWriter clientStreamWriter, IDisposable args) + /// + private void Dispose(Stream clientStream, + CustomBinaryReader clientStreamReader, + StreamWriter clientStreamWriter, + TcpConnection serverConnection) { + ServerConnectionCount--; + clientStream?.Close(); clientStream?.Dispose(); clientStreamReader?.Dispose(); - - clientStreamWriter?.Close(); clientStreamWriter?.Dispose(); - args?.Dispose(); + serverConnection?.Dispose(); } } } diff --git a/Titanium.Web.Proxy/Titanium.Web.Proxy.csproj b/Titanium.Web.Proxy/Titanium.Web.Proxy.csproj index c3eb67485..6b186297c 100644 --- a/Titanium.Web.Proxy/Titanium.Web.Proxy.csproj +++ b/Titanium.Web.Proxy/Titanium.Web.Proxy.csproj @@ -77,6 +77,7 @@ +