Skip to content

Commit

Permalink
Unix Domain Socket WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
caleblloyd committed Oct 22, 2016
1 parent 3f92919 commit 1ac303e
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 31 deletions.
30 changes: 26 additions & 4 deletions src/MySqlConnector/Serialization/ConnectionSettings.cs
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.InteropServices;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using MySql.Data.MySqlClient;
Expand All @@ -14,8 +15,19 @@ public ConnectionSettings(MySqlConnectionStringBuilder csb)
ConnectionString = csb.ConnectionString;

// Base Options
Hostnames = csb.Server.Split(',');
Port = (int) csb.Port;
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && (csb.Server.StartsWith("/") || csb.Server.StartsWith("./")))
{
if (!File.Exists(csb.Server))
throw new MySqlException("Cannot find Unix Socket at " + csb.Server);
ConnectionType = ConnectionType.Unix;
UnixSocket = Path.GetFullPath(csb.Server);
}
else
{
ConnectionType = ConnectionType.Tcp;
Hostnames = csb.Server.Split(',');
Port = (int) csb.Port;
}
UserID = csb.UserID;
Password = csb.Password;
Database = csb.Database;
Expand Down Expand Up @@ -59,18 +71,27 @@ public ConnectionSettings(MySqlConnectionStringBuilder csb)

private ConnectionSettings(ConnectionSettings other, bool? useCompression)
{
// Base Options
ConnectionString = other.ConnectionString;
ConnectionType = other.ConnectionType;
Hostnames = other.Hostnames;
Port = other.Port;
UnixSocket = other.UnixSocket;
UserID = other.UserID;
Password = other.Password;
Database = other.Database;

// SSL/TLS Options
SslMode = other.SslMode;
Certificate = other.Certificate;

// Connection Pooling Options
Pooling = other.Pooling;
ConnectionReset = other.ConnectionReset;
MinimumPoolSize = other.MinimumPoolSize;
MaximumPoolSize = other.MaximumPoolSize;

// Other Options
AllowUserVariables = other.AllowUserVariables;
ConnectionTimeout = other.ConnectionTimeout;
ConvertZeroDateTime = other.ConvertZeroDateTime;
Expand All @@ -81,11 +102,12 @@ private ConnectionSettings(ConnectionSettings other, bool? useCompression)
UseCompression = useCompression ?? other.UseCompression;
}

internal readonly string ConnectionString;

// Base Options
internal readonly string ConnectionString;
internal readonly ConnectionType ConnectionType;
internal readonly IEnumerable<string> Hostnames;
internal readonly int Port;
internal readonly string UnixSocket;
internal readonly string UserID;
internal readonly string Password;
internal readonly string Database;
Expand Down
18 changes: 18 additions & 0 deletions src/MySqlConnector/Serialization/ConnectionType.cs
@@ -0,0 +1,18 @@
namespace MySql.Data.Serialization
{
/// <summary>
/// Specifies whether to perform synchronous or asynchronous I/O.
/// </summary>
internal enum ConnectionType
{
/// <summary>
/// Connection is a TCP connection.
/// </summary>
Tcp,

/// <summary>
/// Connection is a Unix Domain Socket.
/// </summary>
Unix,
}
}
125 changes: 98 additions & 27 deletions src/MySqlConnector/Serialization/MySqlSession.cs
Expand Up @@ -3,6 +3,7 @@
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
Expand Down Expand Up @@ -49,32 +50,24 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella
{
// socket may have been closed during shutdown; ignore
}
m_payloadHandler = null;
}
if (m_tcpClient != null)
{
try
{
#if NETSTANDARD1_3
m_tcpClient.Dispose();
#else
m_tcpClient.Close();
#endif
}
catch (SocketException)
{
}
m_tcpClient = null;
}
ShutdownSocket();
m_state = State.Closed;
}

public async Task ConnectAsync(ConnectionSettings cs, IOBehavior ioBehavior, CancellationToken cancellationToken)
{
var connected = await OpenSocketAsync(cs, ioBehavior, cancellationToken).ConfigureAwait(false);
var connected = false;
if (cs.ConnectionType == ConnectionType.Tcp)
connected = await OpenTcpSocketAsync(cs, ioBehavior, cancellationToken).ConfigureAwait(false);
else if (cs.ConnectionType == ConnectionType.Unix)
connected = await OpenUnixSocketAsync(cs, ioBehavior, cancellationToken).ConfigureAwait(false);
if (!connected)
throw new MySqlException("Unable to connect to any of the specified MySQL hosts.");

var socketByteHandler = new SocketByteHandler(m_socket);
m_payloadHandler = new StandardPayloadHandler(socketByteHandler);

var payload = await ReceiveAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
var reader = new ByteArrayReader(payload.ArraySegment.Array, payload.ArraySegment.Offset, payload.ArraySegment.Count);
var initialHandshake = new InitialHandshakePacket(reader);
Expand Down Expand Up @@ -184,7 +177,7 @@ private void VerifyConnected()
throw new InvalidOperationException("MySqlSession is not connected.");
}

private async Task<bool> OpenSocketAsync(ConnectionSettings cs, IOBehavior ioBehavior, CancellationToken cancellationToken)
private async Task<bool> OpenTcpSocketAsync(ConnectionSettings cs, IOBehavior ioBehavior, CancellationToken cancellationToken)
{
foreach (var hostname in cs.Hostnames)
{
Expand Down Expand Up @@ -249,9 +242,8 @@ private async Task<bool> OpenSocketAsync(ConnectionSettings cs, IOBehavior ioBeh

m_hostname = hostname;
m_tcpClient = tcpClient;

var socketByteHandler = new SocketByteHandler(m_tcpClient.Client);
m_payloadHandler = new StandardPayloadHandler(socketByteHandler);
m_socket = m_tcpClient.Client;
m_networkStream = m_tcpClient.GetStream();

m_state = State.Connected;
return true;
Expand All @@ -260,6 +252,56 @@ private async Task<bool> OpenSocketAsync(ConnectionSettings cs, IOBehavior ioBeh
return false;
}

private async Task<bool> OpenUnixSocketAsync(ConnectionSettings cs, IOBehavior ioBehavior, CancellationToken cancellationToken)
{
var socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.IP);
var unixEp = new UnixEndPoint(cs.UnixSocket);
try
{
using (cancellationToken.Register(() => socket.Dispose()))
{
try
{
if (ioBehavior == IOBehavior.Asynchronous)
{
#if NETSTANDARD1_3
socket.Connect(unixEp);
#else
await Task.Factory.FromAsync(socket.BeginConnect, socket.EndConnect, unixEp, null).ConfigureAwait(false);
#endif
}
else
{
#if NETSTANDARD1_3
await socket.ConnectAsync(unixEp).ConfigureAwait(false);
#else
socket.Connect(unixEp);
#endif
}
}
catch (ObjectDisposedException ex) when (cancellationToken.IsCancellationRequested)
{
throw new MySqlException("Connect Timeout expired.", ex);
}
}
}
catch (SocketException)
{
socket.Dispose();
}

if (socket.Connected)
{
m_socket = socket;
m_networkStream = new NetworkStream(socket);

m_state = State.Connected;
return true;
}

return false;
}

private async Task InitSslAsync(ConnectionSettings cs, IOBehavior ioBehavior, CancellationToken cancellationToken)
{
Func<object, string, X509CertificateCollection, X509Certificate, string[], X509Certificate> localCertificateCb =
Expand All @@ -279,7 +321,7 @@ private async Task InitSslAsync(ConnectionSettings cs, IOBehavior ioBehavior, Ca
}
};

var sslStream = new SslStream(m_tcpClient.GetStream(), false,
var sslStream = new SslStream(m_networkStream, false,
new RemoteCertificateValidationCallback(remoteCertificateCb),
new LocalCertificateSelectionCallback(localCertificateCb));
var clientCertificates = new X509CertificateCollection { cs.Certificate };
Expand Down Expand Up @@ -313,16 +355,42 @@ private async Task InitSslAsync(ConnectionSettings cs, IOBehavior ioBehavior, Ca
}
catch (AuthenticationException ex)
{
ShutdownSocket();
m_hostname = "";
m_state = State.Failed;
throw new MySqlException("SSL Authentication Error", ex);
}
}

private void ShutdownSocket()
{
m_payloadHandler = null;
if (m_tcpClient != null)
{
try
{
#if NETSTANDARD1_3
m_tcpClient.Dispose();
#else
m_tcpClient.Close();
m_tcpClient.Close();
#endif
m_hostname = "";
m_payloadHandler = null;
m_state = State.Failed;
}
catch (SocketException)
{
}
m_tcpClient = null;
throw new MySqlException("SSL Authentication Error", ex);
m_socket = null;
}
else if (m_socket != null)
{
try
{
m_socket.Dispose();
m_socket = null;
}
catch (SocketException)
{
}
}
}

Expand Down Expand Up @@ -389,7 +457,10 @@ private enum State

State m_state;
string m_hostname;

TcpClient m_tcpClient;
Socket m_socket;
NetworkStream m_networkStream;
IPayloadHandler m_payloadHandler;
}
}
106 changes: 106 additions & 0 deletions src/MySqlConnector/UnixEndPoint.cs
@@ -0,0 +1,106 @@
using System;
using System.Net;
using System.Net.Sockets;
using System.Text;

namespace MySql.Data
{
// copied from https://github.com/mono/mono/blob/master/mcs/class/Mono.Posix/Mono.Unix/UnixEndPoint.cs
#if NETSTANDARD1_3
#else
[Serializable]
#endif
public class UnixEndPoint : EndPoint
{
string filename;

public UnixEndPoint (string filename)
{
if (filename == null)
throw new ArgumentNullException ("filename");

if (filename == "")
throw new ArgumentException ("Cannot be empty.", "filename");
this.filename = filename;
}

public string Filename {
get {
return(filename);
}
set {
filename=value;
}
}

public override AddressFamily AddressFamily {
get { return AddressFamily.Unix; }
}

public override EndPoint Create (SocketAddress socketAddress)
{
/*
* Should also check this
*
int addr = (int) AddressFamily.Unix;
if (socketAddress [0] != (addr & 0xFF))
throw new ArgumentException ("socketAddress is not a unix socket address.");
if (socketAddress [1] != ((addr & 0xFF00) >> 8))
throw new ArgumentException ("socketAddress is not a unix socket address.");
*/

if (socketAddress.Size == 2) {
// Empty filename.
// Probably from RemoteEndPoint which on linux does not return the file name.
UnixEndPoint uep = new UnixEndPoint ("a");
uep.filename = "";
return uep;
}
int size = socketAddress.Size - 2;
byte [] bytes = new byte [size];
for (int i = 0; i < bytes.Length; i++) {
bytes [i] = socketAddress [i + 2];
// There may be junk after the null terminator, so ignore it all.
if (bytes [i] == 0) {
size = i;
break;
}
}

string name = Encoding.UTF8.GetString (bytes, 0, size);
return new UnixEndPoint (name);
}

public override SocketAddress Serialize ()
{
byte [] bytes = Encoding.UTF8.GetBytes (filename);
SocketAddress sa = new SocketAddress (AddressFamily, 2 + bytes.Length + 1);
// sa [0] -> family low byte, sa [1] -> family high byte
for (int i = 0; i < bytes.Length; i++)
sa [2 + i] = bytes [i];

//NULL suffix for non-abstract path
sa[2 + bytes.Length] = 0;

return sa;
}

public override string ToString() {
return(filename);
}

public override int GetHashCode ()
{
return filename.GetHashCode ();
}

public override bool Equals (object o)
{
UnixEndPoint other = o as UnixEndPoint;
if (other == null)
return false;

return (other.filename == filename);
}
}
}

0 comments on commit 1ac303e

Please sign in to comment.