From fb81de3548537e8777b79e563eb13562602afd5f Mon Sep 17 00:00:00 2001 From: rusher Date: Mon, 3 Sep 2018 02:27:23 -0700 Subject: [PATCH] [CONJ-634] avoid import of non dependency to permit compilation Ahead of time --- .../com/send/SendGssApiAuthPacket.java | 41 ++++------ .../com/send/SendHandshakeResponsePacket.java | 16 +++- .../internal/com/send/gssapi/GssUtility.java | 33 +++++++++ .../io/socket/SocketHandlerFunction.java | 11 +++ .../internal/io/socket/SocketUtility.java | 41 ++++++++++ .../jdbc/internal/util/PidFactory.java | 54 +++++--------- .../jdbc/internal/util/PidRequestInter.java | 5 ++ .../org/mariadb/jdbc/internal/util/Utils.java | 74 +++++++++---------- 8 files changed, 174 insertions(+), 101 deletions(-) create mode 100644 src/main/java/org/mariadb/jdbc/internal/com/send/gssapi/GssUtility.java create mode 100644 src/main/java/org/mariadb/jdbc/internal/io/socket/SocketHandlerFunction.java create mode 100644 src/main/java/org/mariadb/jdbc/internal/io/socket/SocketUtility.java create mode 100644 src/main/java/org/mariadb/jdbc/internal/util/PidRequestInter.java diff --git a/src/main/java/org/mariadb/jdbc/internal/com/send/SendGssApiAuthPacket.java b/src/main/java/org/mariadb/jdbc/internal/com/send/SendGssApiAuthPacket.java index 55771a614..f8330116a 100644 --- a/src/main/java/org/mariadb/jdbc/internal/com/send/SendGssApiAuthPacket.java +++ b/src/main/java/org/mariadb/jdbc/internal/com/send/SendGssApiAuthPacket.java @@ -52,12 +52,11 @@ package org.mariadb.jdbc.internal.com.send; -import com.sun.jna.Platform; import org.mariadb.jdbc.internal.com.read.Buffer; import org.mariadb.jdbc.internal.com.read.ErrorPacket; +import org.mariadb.jdbc.internal.com.send.gssapi.GssUtility; import org.mariadb.jdbc.internal.com.send.gssapi.GssapiAuth; import org.mariadb.jdbc.internal.com.send.gssapi.StandardGssapiAuthentication; -import org.mariadb.jdbc.internal.com.send.gssapi.WindowsNativeSspiAuthentication; import org.mariadb.jdbc.internal.io.input.PacketInputStream; import org.mariadb.jdbc.internal.io.output.PacketOutputStream; @@ -65,11 +64,24 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.sql.SQLException; +import java.util.function.BiFunction; import static org.mariadb.jdbc.internal.com.Packet.ERROR; public class SendGssApiAuthPacket extends AbstractAuthSwitchSendResponsePacket implements InterfaceAuthSwitchSendResponsePacket { private final PacketInputStream reader; + private final static BiFunction gssMethod; + + static { + BiFunction init; + try { + init = GssUtility.getAuthenticationMethod(); + } catch (Throwable t) { + BiFunction defaultAuthenticationMethod = (reader, packSeq) -> new StandardGssapiAuthentication(reader, packSeq); + init = defaultAuthenticationMethod; + } + gssMethod = init; + } public SendGssApiAuthPacket(PacketInputStream reader, String password, byte[] authData, int packSeq, String passwordCharacterEncoding) { super(packSeq, authData, password, passwordCharacterEncoding); @@ -88,7 +100,7 @@ public void send(PacketOutputStream pos) throws IOException, SQLException { String mechanisms = buffer.readStringNullEnd(StandardCharsets.UTF_8); if (mechanisms.isEmpty()) mechanisms = "Kerberos"; - GssapiAuth gssapiAuth = getAuthenticationMethod(); + GssapiAuth gssapiAuth = gssMethod.apply(reader, packSeq); gssapiAuth.authenticate(pos, serverPrincipalName, mechanisms); } @@ -107,28 +119,5 @@ public void handleResultPacket(PacketInputStream reader) throws SQLException, IO } } - /** - * Get authentication method according to classpath. - * Windows native authentication is using Waffle-jna. - * - * @return authentication method - */ - private GssapiAuth getAuthenticationMethod() { - try { - //Waffle-jna has jna as dependency, so if not available on classpath, just use standard authentication - if (Platform.isWindows()) { - try { - Class.forName("waffle.windows.auth.impl.WindowsAuthProviderImpl"); - return new WindowsNativeSspiAuthentication(reader, packSeq); - } catch (ClassNotFoundException cle) { - //waffle not in the classpath - } - } - } catch (Throwable cle) { - //jna jar's are not in classpath - } - return new StandardGssapiAuthentication(reader, packSeq); - } - } diff --git a/src/main/java/org/mariadb/jdbc/internal/com/send/SendHandshakeResponsePacket.java b/src/main/java/org/mariadb/jdbc/internal/com/send/SendHandshakeResponsePacket.java index d28316f18..810fb218c 100644 --- a/src/main/java/org/mariadb/jdbc/internal/com/send/SendHandshakeResponsePacket.java +++ b/src/main/java/org/mariadb/jdbc/internal/com/send/SendHandshakeResponsePacket.java @@ -61,6 +61,7 @@ import org.mariadb.jdbc.internal.protocol.authentication.DefaultAuthenticationProvider; import org.mariadb.jdbc.internal.util.Options; import org.mariadb.jdbc.internal.util.PidFactory; +import org.mariadb.jdbc.internal.util.PidRequestInter; import org.mariadb.jdbc.internal.util.Utils; import org.mariadb.jdbc.internal.util.constant.Version; @@ -92,6 +93,18 @@ */ public class SendHandshakeResponsePacket { + private static final PidRequestInter pidRequest; + + static { + PidRequestInter init; + try { + init = PidFactory.getInstance(); + } catch (Throwable t) { + init = () -> null; + } + pidRequest = init; + } + /** * Send handshake response packet. * @@ -214,8 +227,7 @@ private static void writeConnectAttributes(PacketOutputStream pos, String connec buffer.writeStringSmallLength(_OS); buffer.writeStringLength(System.getProperty("os.name")); - - String pid = PidFactory.getInstance().getPid(); + String pid = pidRequest.getPid(); if (pid != null) { buffer.writeStringSmallLength(_PID); buffer.writeStringLength(pid); diff --git a/src/main/java/org/mariadb/jdbc/internal/com/send/gssapi/GssUtility.java b/src/main/java/org/mariadb/jdbc/internal/com/send/gssapi/GssUtility.java new file mode 100644 index 000000000..801c8035a --- /dev/null +++ b/src/main/java/org/mariadb/jdbc/internal/com/send/gssapi/GssUtility.java @@ -0,0 +1,33 @@ +package org.mariadb.jdbc.internal.com.send.gssapi; + +import com.sun.jna.Platform; +import org.mariadb.jdbc.internal.io.input.PacketInputStream; + +import java.util.function.BiFunction; + +public class GssUtility { + + /** + * Get authentication method according to classpath. + * Windows native authentication is using Waffle-jna. + * + * @return authentication method + */ + public static BiFunction getAuthenticationMethod() { + try { + //Waffle-jna has jna as dependency, so if not available on classpath, just use standard authentication + if (Platform.isWindows()) { + try { + Class.forName("waffle.windows.auth.impl.WindowsAuthProviderImpl"); + return (reader, packSeq) -> new WindowsNativeSspiAuthentication(reader, packSeq); + } catch (ClassNotFoundException cle) { + //waffle not in the classpath + } + } + } catch (Throwable cle) { + //jna jar's are not in classpath + } + return (reader, packSeq) -> new StandardGssapiAuthentication(reader, packSeq); + } + +} diff --git a/src/main/java/org/mariadb/jdbc/internal/io/socket/SocketHandlerFunction.java b/src/main/java/org/mariadb/jdbc/internal/io/socket/SocketHandlerFunction.java new file mode 100644 index 000000000..8117ac31f --- /dev/null +++ b/src/main/java/org/mariadb/jdbc/internal/io/socket/SocketHandlerFunction.java @@ -0,0 +1,11 @@ +package org.mariadb.jdbc.internal.io.socket; + +import org.mariadb.jdbc.UrlParser; + +import java.io.IOException; +import java.net.Socket; + +@FunctionalInterface +public interface SocketHandlerFunction { + Socket apply(UrlParser a, String b) throws IOException; +} \ No newline at end of file diff --git a/src/main/java/org/mariadb/jdbc/internal/io/socket/SocketUtility.java b/src/main/java/org/mariadb/jdbc/internal/io/socket/SocketUtility.java new file mode 100644 index 000000000..0f56800c1 --- /dev/null +++ b/src/main/java/org/mariadb/jdbc/internal/io/socket/SocketUtility.java @@ -0,0 +1,41 @@ +package org.mariadb.jdbc.internal.io.socket; + +import com.sun.jna.Platform; +import org.mariadb.jdbc.internal.util.Utils; + +import java.io.IOException; + +public class SocketUtility { + + @SuppressWarnings("unchecked") + public static SocketHandlerFunction getSocketHandler() { + try { + //forcing use of JNA to ensure AOT compilation + Platform.getOSType(); + + return (urlParser, host) -> { + if (urlParser.getOptions().pipe != null) { + return new NamedPipeSocket(host, urlParser.getOptions().pipe); + } else if (urlParser.getOptions().localSocket != null) { + try { + return new UnixDomainSocket(urlParser.getOptions().localSocket); + } catch (RuntimeException re) { + throw new IOException(re.getMessage(), re.getCause()); + } + } else if (urlParser.getOptions().sharedMemory != null) { + try { + return new SharedMemorySocket(urlParser.getOptions().sharedMemory); + } catch (RuntimeException re) { + throw new IOException(re.getMessage(), re.getCause()); + } + } else { + return Utils.standardSocket(urlParser, host); + } + + }; + } catch (Throwable cle) { + //jna jar's are not in classpath + } + return (urlParser, host) -> Utils.standardSocket(urlParser, host); + } +} diff --git a/src/main/java/org/mariadb/jdbc/internal/util/PidFactory.java b/src/main/java/org/mariadb/jdbc/internal/util/PidFactory.java index c9e1d20fe..0dbc56961 100644 --- a/src/main/java/org/mariadb/jdbc/internal/util/PidFactory.java +++ b/src/main/java/org/mariadb/jdbc/internal/util/PidFactory.java @@ -59,55 +59,39 @@ public class PidFactory { - private static PidRequestInter pidRequest = null; - /** * Factory method to avoid loading JNA classes every connection. * * @return factory that implement PID according to environment. */ public static PidRequestInter getInstance() { - if (pidRequest == null) { - synchronized (PidFactory.class) { - // check again within synchronized block to guard for race condition - if (pidRequest == null) { - //initialize JNA methods - try { - if (Platform.isLinux()) { - //Linux pid implementation - pidRequest = () -> String.valueOf(CLibrary.INSTANCE.getpid()); - } else { - if (Platform.isWindows()) { - //Windows pid implementation - pidRequest = () -> { - try { - return String.valueOf(Kernel32.INSTANCE.GetCurrentProcessId()); - } catch (Throwable cle) { - //jna plateform jar's are not in classpath, no PID returned - } - return null; - }; - - } + try { + if (Platform.isLinux()) { + //Linux pid implementation + return () -> String.valueOf(CLibrary.INSTANCE.getpid()); + } else { + if (Platform.isWindows()) { + //Windows pid implementation + return () -> { + try { + return String.valueOf(Kernel32.INSTANCE.GetCurrentProcessId()); + } catch (Throwable cle) { + //jna plateform jar's are not in classpath, no PID returned } - } catch (Throwable cle) { - //jna jar's are not in classpath, no PID returned - } + return null; + }; - //No JNA, or environment not Linux/windows -> return no PID - if (pidRequest == null) { - pidRequest = () -> null; - } } } + } catch (Throwable cle) { + //jna jar's are not in classpath, no PID returned } - return pidRequest; - } - public interface PidRequestInter { - String getPid(); + //No JNA, or environment not Linux/windows -> return no PID + return () -> null; } + private interface CLibrary extends Library { CLibrary INSTANCE = (CLibrary) Native.loadLibrary("c", CLibrary.class); int getpid(); diff --git a/src/main/java/org/mariadb/jdbc/internal/util/PidRequestInter.java b/src/main/java/org/mariadb/jdbc/internal/util/PidRequestInter.java new file mode 100644 index 000000000..66b8c60f6 --- /dev/null +++ b/src/main/java/org/mariadb/jdbc/internal/util/PidRequestInter.java @@ -0,0 +1,5 @@ +package org.mariadb.jdbc.internal.util; + +public interface PidRequestInter { + String getPid(); +} diff --git a/src/main/java/org/mariadb/jdbc/internal/util/Utils.java b/src/main/java/org/mariadb/jdbc/internal/util/Utils.java index 607257196..f7d887350 100644 --- a/src/main/java/org/mariadb/jdbc/internal/util/Utils.java +++ b/src/main/java/org/mariadb/jdbc/internal/util/Utils.java @@ -57,9 +57,7 @@ import org.mariadb.jdbc.internal.failover.impl.AuroraListener; import org.mariadb.jdbc.internal.failover.impl.MastersFailoverListener; import org.mariadb.jdbc.internal.failover.impl.MastersSlavesListener; -import org.mariadb.jdbc.internal.io.socket.NamedPipeSocket; -import org.mariadb.jdbc.internal.io.socket.SharedMemorySocket; -import org.mariadb.jdbc.internal.io.socket.UnixDomainSocket; +import org.mariadb.jdbc.internal.io.socket.*; import org.mariadb.jdbc.internal.logging.ProtocolLoggingProxy; import org.mariadb.jdbc.internal.protocol.AuroraProtocol; import org.mariadb.jdbc.internal.protocol.MasterProtocol; @@ -93,6 +91,40 @@ public class Utils { private static final Pattern IP_V6_COMPRESSED = Pattern.compile("^(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)" + "::(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)$"); + private final static SocketHandlerFunction socketHandler; + + static { + SocketHandlerFunction init; + try { + init = SocketUtility.getSocketHandler(); + } catch (Throwable t) { + SocketHandlerFunction defaultSocketHandler = (urlParser, host) -> Utils.standardSocket(urlParser, host); + init = defaultSocketHandler; + } + socketHandler = init; + } + + public static Socket standardSocket(UrlParser urlParser, String host) throws IOException { + SocketFactory socketFactory; + String socketFactoryName = urlParser.getOptions().socketFactory; + if (socketFactoryName != null) { + try { + @SuppressWarnings("unchecked") + Class socketFactoryClass = (Class) Class.forName(socketFactoryName); + if (socketFactoryClass != null) { + Constructor constructor = socketFactoryClass.getConstructor(); + socketFactory = constructor.newInstance(); + return socketFactory.createSocket(); + } + } catch (Exception exp) { + throw new IOException("Socket factory failed to initialized with option \"socketFactory\" set to \"" + + urlParser.getOptions().socketFactory + "\"", exp); + } + } + socketFactory = SocketFactory.getDefault(); + return socketFactory.createSocket(); + } + /** * Escape String. * @@ -542,42 +574,8 @@ public static TimeZone getTimeZone(String id) throws SQLException { * @return a nex socket * @throws IOException if connection error occur */ - @SuppressWarnings("unchecked") public static Socket createSocket(UrlParser urlParser, String host) throws IOException { - - if (urlParser.getOptions().pipe != null) { - return new NamedPipeSocket(host, urlParser.getOptions().pipe); - } else if (urlParser.getOptions().localSocket != null) { - try { - return new UnixDomainSocket(urlParser.getOptions().localSocket); - } catch (RuntimeException re) { - throw new IOException(re.getMessage(), re.getCause()); - } - } else if (urlParser.getOptions().sharedMemory != null) { - try { - return new SharedMemorySocket(urlParser.getOptions().sharedMemory); - } catch (RuntimeException re) { - throw new IOException(re.getMessage(), re.getCause()); - } - } else { - SocketFactory socketFactory; - String socketFactoryName = urlParser.getOptions().socketFactory; - if (socketFactoryName != null) { - try { - Class socketFactoryClass = (Class) Class.forName(socketFactoryName); - if (socketFactoryClass != null) { - Constructor constructor = socketFactoryClass.getConstructor(); - socketFactory = constructor.newInstance(); - return socketFactory.createSocket(); - } - } catch (Exception exp) { - throw new IOException("Socket factory failed to initialized with option \"socketFactory\" set to \"" - + urlParser.getOptions().socketFactory + "\"", exp); - } - } - socketFactory = SocketFactory.getDefault(); - return socketFactory.createSocket(); - } + return (Socket) socketHandler.apply(urlParser, host); } /**