Skip to content

Commit

Permalink
[CONJ-634] avoid import of non dependency to permit compilation Ahead…
Browse files Browse the repository at this point in the history
… of time
  • Loading branch information
rusher committed Sep 3, 2018
1 parent a4fe8a6 commit fb81de3
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 101 deletions.
Expand Up @@ -52,24 +52,36 @@

package org.mariadb.jdbc.internal.com.send;

import com.sun.jna.Platform;
import org.mariadb.jdbc.internal.com.read.Buffer;
import org.mariadb.jdbc.internal.com.read.ErrorPacket;
import org.mariadb.jdbc.internal.com.send.gssapi.GssUtility;
import org.mariadb.jdbc.internal.com.send.gssapi.GssapiAuth;
import org.mariadb.jdbc.internal.com.send.gssapi.StandardGssapiAuthentication;
import org.mariadb.jdbc.internal.com.send.gssapi.WindowsNativeSspiAuthentication;
import org.mariadb.jdbc.internal.io.input.PacketInputStream;
import org.mariadb.jdbc.internal.io.output.PacketOutputStream;

import java.io.EOFException;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.sql.SQLException;
import java.util.function.BiFunction;

import static org.mariadb.jdbc.internal.com.Packet.ERROR;

public class SendGssApiAuthPacket extends AbstractAuthSwitchSendResponsePacket implements InterfaceAuthSwitchSendResponsePacket {
private final PacketInputStream reader;
private final static BiFunction<PacketInputStream, Integer, GssapiAuth> gssMethod;

static {
BiFunction<PacketInputStream, Integer, GssapiAuth> init;
try {
init = GssUtility.getAuthenticationMethod();
} catch (Throwable t) {
BiFunction<PacketInputStream, Integer, GssapiAuth> defaultAuthenticationMethod = (reader, packSeq) -> new StandardGssapiAuthentication(reader, packSeq);
init = defaultAuthenticationMethod;
}
gssMethod = init;
}

public SendGssApiAuthPacket(PacketInputStream reader, String password, byte[] authData, int packSeq, String passwordCharacterEncoding) {
super(packSeq, authData, password, passwordCharacterEncoding);
Expand All @@ -88,7 +100,7 @@ public void send(PacketOutputStream pos) throws IOException, SQLException {
String mechanisms = buffer.readStringNullEnd(StandardCharsets.UTF_8);
if (mechanisms.isEmpty()) mechanisms = "Kerberos";

GssapiAuth gssapiAuth = getAuthenticationMethod();
GssapiAuth gssapiAuth = gssMethod.apply(reader, packSeq);
gssapiAuth.authenticate(pos, serverPrincipalName, mechanisms);
}

Expand All @@ -107,28 +119,5 @@ public void handleResultPacket(PacketInputStream reader) throws SQLException, IO
}
}

/**
* Get authentication method according to classpath.
* Windows native authentication is using Waffle-jna.
*
* @return authentication method
*/
private GssapiAuth getAuthenticationMethod() {
try {
//Waffle-jna has jna as dependency, so if not available on classpath, just use standard authentication
if (Platform.isWindows()) {
try {
Class.forName("waffle.windows.auth.impl.WindowsAuthProviderImpl");
return new WindowsNativeSspiAuthentication(reader, packSeq);
} catch (ClassNotFoundException cle) {
//waffle not in the classpath
}
}
} catch (Throwable cle) {
//jna jar's are not in classpath
}
return new StandardGssapiAuthentication(reader, packSeq);
}

}

Expand Up @@ -61,6 +61,7 @@
import org.mariadb.jdbc.internal.protocol.authentication.DefaultAuthenticationProvider;
import org.mariadb.jdbc.internal.util.Options;
import org.mariadb.jdbc.internal.util.PidFactory;
import org.mariadb.jdbc.internal.util.PidRequestInter;
import org.mariadb.jdbc.internal.util.Utils;
import org.mariadb.jdbc.internal.util.constant.Version;

Expand Down Expand Up @@ -92,6 +93,18 @@
*/
public class SendHandshakeResponsePacket {

private static final PidRequestInter pidRequest;

static {
PidRequestInter init;
try {
init = PidFactory.getInstance();
} catch (Throwable t) {
init = () -> null;
}
pidRequest = init;
}

/**
* Send handshake response packet.
*
Expand Down Expand Up @@ -214,8 +227,7 @@ private static void writeConnectAttributes(PacketOutputStream pos, String connec

buffer.writeStringSmallLength(_OS);
buffer.writeStringLength(System.getProperty("os.name"));

String pid = PidFactory.getInstance().getPid();
String pid = pidRequest.getPid();
if (pid != null) {
buffer.writeStringSmallLength(_PID);
buffer.writeStringLength(pid);
Expand Down
@@ -0,0 +1,33 @@
package org.mariadb.jdbc.internal.com.send.gssapi;

import com.sun.jna.Platform;
import org.mariadb.jdbc.internal.io.input.PacketInputStream;

import java.util.function.BiFunction;

public class GssUtility {

/**
* Get authentication method according to classpath.
* Windows native authentication is using Waffle-jna.
*
* @return authentication method
*/
public static BiFunction<PacketInputStream, Integer, GssapiAuth> getAuthenticationMethod() {
try {
//Waffle-jna has jna as dependency, so if not available on classpath, just use standard authentication
if (Platform.isWindows()) {
try {
Class.forName("waffle.windows.auth.impl.WindowsAuthProviderImpl");
return (reader, packSeq) -> new WindowsNativeSspiAuthentication(reader, packSeq);
} catch (ClassNotFoundException cle) {
//waffle not in the classpath
}
}
} catch (Throwable cle) {
//jna jar's are not in classpath
}
return (reader, packSeq) -> new StandardGssapiAuthentication(reader, packSeq);
}

}
@@ -0,0 +1,11 @@
package org.mariadb.jdbc.internal.io.socket;

import org.mariadb.jdbc.UrlParser;

import java.io.IOException;
import java.net.Socket;

@FunctionalInterface
public interface SocketHandlerFunction {
Socket apply(UrlParser a, String b) throws IOException;
}
@@ -0,0 +1,41 @@
package org.mariadb.jdbc.internal.io.socket;

import com.sun.jna.Platform;
import org.mariadb.jdbc.internal.util.Utils;

import java.io.IOException;

public class SocketUtility {

@SuppressWarnings("unchecked")
public static SocketHandlerFunction getSocketHandler() {
try {
//forcing use of JNA to ensure AOT compilation
Platform.getOSType();

return (urlParser, host) -> {
if (urlParser.getOptions().pipe != null) {
return new NamedPipeSocket(host, urlParser.getOptions().pipe);
} else if (urlParser.getOptions().localSocket != null) {
try {
return new UnixDomainSocket(urlParser.getOptions().localSocket);
} catch (RuntimeException re) {
throw new IOException(re.getMessage(), re.getCause());
}
} else if (urlParser.getOptions().sharedMemory != null) {
try {
return new SharedMemorySocket(urlParser.getOptions().sharedMemory);
} catch (RuntimeException re) {
throw new IOException(re.getMessage(), re.getCause());
}
} else {
return Utils.standardSocket(urlParser, host);
}

};
} catch (Throwable cle) {
//jna jar's are not in classpath
}
return (urlParser, host) -> Utils.standardSocket(urlParser, host);
}
}
54 changes: 19 additions & 35 deletions src/main/java/org/mariadb/jdbc/internal/util/PidFactory.java
Expand Up @@ -59,55 +59,39 @@

public class PidFactory {

private static PidRequestInter pidRequest = null;

/**
* Factory method to avoid loading JNA classes every connection.
*
* @return factory that implement PID according to environment.
*/
public static PidRequestInter getInstance() {
if (pidRequest == null) {
synchronized (PidFactory.class) {
// check again within synchronized block to guard for race condition
if (pidRequest == null) {
//initialize JNA methods
try {
if (Platform.isLinux()) {
//Linux pid implementation
pidRequest = () -> String.valueOf(CLibrary.INSTANCE.getpid());
} else {
if (Platform.isWindows()) {
//Windows pid implementation
pidRequest = () -> {
try {
return String.valueOf(Kernel32.INSTANCE.GetCurrentProcessId());
} catch (Throwable cle) {
//jna plateform jar's are not in classpath, no PID returned
}
return null;
};

}
try {
if (Platform.isLinux()) {
//Linux pid implementation
return () -> String.valueOf(CLibrary.INSTANCE.getpid());
} else {
if (Platform.isWindows()) {
//Windows pid implementation
return () -> {
try {
return String.valueOf(Kernel32.INSTANCE.GetCurrentProcessId());
} catch (Throwable cle) {
//jna plateform jar's are not in classpath, no PID returned
}
} catch (Throwable cle) {
//jna jar's are not in classpath, no PID returned
}
return null;
};

//No JNA, or environment not Linux/windows -> return no PID
if (pidRequest == null) {
pidRequest = () -> null;
}
}
}
} catch (Throwable cle) {
//jna jar's are not in classpath, no PID returned
}
return pidRequest;
}

public interface PidRequestInter {
String getPid();
//No JNA, or environment not Linux/windows -> return no PID
return () -> null;
}


private interface CLibrary extends Library {
CLibrary INSTANCE = (CLibrary) Native.loadLibrary("c", CLibrary.class);
int getpid();
Expand Down
@@ -0,0 +1,5 @@
package org.mariadb.jdbc.internal.util;

public interface PidRequestInter {
String getPid();
}
74 changes: 36 additions & 38 deletions src/main/java/org/mariadb/jdbc/internal/util/Utils.java
Expand Up @@ -57,9 +57,7 @@
import org.mariadb.jdbc.internal.failover.impl.AuroraListener;
import org.mariadb.jdbc.internal.failover.impl.MastersFailoverListener;
import org.mariadb.jdbc.internal.failover.impl.MastersSlavesListener;
import org.mariadb.jdbc.internal.io.socket.NamedPipeSocket;
import org.mariadb.jdbc.internal.io.socket.SharedMemorySocket;
import org.mariadb.jdbc.internal.io.socket.UnixDomainSocket;
import org.mariadb.jdbc.internal.io.socket.*;
import org.mariadb.jdbc.internal.logging.ProtocolLoggingProxy;
import org.mariadb.jdbc.internal.protocol.AuroraProtocol;
import org.mariadb.jdbc.internal.protocol.MasterProtocol;
Expand Down Expand Up @@ -93,6 +91,40 @@ public class Utils {
private static final Pattern IP_V6_COMPRESSED = Pattern.compile("^(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)"
+ "::(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)$");

private final static SocketHandlerFunction socketHandler;

static {
SocketHandlerFunction init;
try {
init = SocketUtility.getSocketHandler();
} catch (Throwable t) {
SocketHandlerFunction defaultSocketHandler = (urlParser, host) -> Utils.standardSocket(urlParser, host);
init = defaultSocketHandler;
}
socketHandler = init;
}

public static Socket standardSocket(UrlParser urlParser, String host) throws IOException {
SocketFactory socketFactory;
String socketFactoryName = urlParser.getOptions().socketFactory;
if (socketFactoryName != null) {
try {
@SuppressWarnings("unchecked")
Class<? extends SocketFactory> socketFactoryClass = (Class<? extends SocketFactory>) Class.forName(socketFactoryName);
if (socketFactoryClass != null) {
Constructor<? extends SocketFactory> constructor = socketFactoryClass.getConstructor();
socketFactory = constructor.newInstance();
return socketFactory.createSocket();
}
} catch (Exception exp) {
throw new IOException("Socket factory failed to initialized with option \"socketFactory\" set to \""
+ urlParser.getOptions().socketFactory + "\"", exp);
}
}
socketFactory = SocketFactory.getDefault();
return socketFactory.createSocket();
}

/**
* Escape String.
*
Expand Down Expand Up @@ -542,42 +574,8 @@ public static TimeZone getTimeZone(String id) throws SQLException {
* @return a nex socket
* @throws IOException if connection error occur
*/
@SuppressWarnings("unchecked")
public static Socket createSocket(UrlParser urlParser, String host) throws IOException {

if (urlParser.getOptions().pipe != null) {
return new NamedPipeSocket(host, urlParser.getOptions().pipe);
} else if (urlParser.getOptions().localSocket != null) {
try {
return new UnixDomainSocket(urlParser.getOptions().localSocket);
} catch (RuntimeException re) {
throw new IOException(re.getMessage(), re.getCause());
}
} else if (urlParser.getOptions().sharedMemory != null) {
try {
return new SharedMemorySocket(urlParser.getOptions().sharedMemory);
} catch (RuntimeException re) {
throw new IOException(re.getMessage(), re.getCause());
}
} else {
SocketFactory socketFactory;
String socketFactoryName = urlParser.getOptions().socketFactory;
if (socketFactoryName != null) {
try {
Class<? extends SocketFactory> socketFactoryClass = (Class<? extends SocketFactory>) Class.forName(socketFactoryName);
if (socketFactoryClass != null) {
Constructor<? extends SocketFactory> constructor = socketFactoryClass.getConstructor();
socketFactory = constructor.newInstance();
return socketFactory.createSocket();
}
} catch (Exception exp) {
throw new IOException("Socket factory failed to initialized with option \"socketFactory\" set to \""
+ urlParser.getOptions().socketFactory + "\"", exp);
}
}
socketFactory = SocketFactory.getDefault();
return socketFactory.createSocket();
}
return (Socket) socketHandler.apply(urlParser, host);
}

/**
Expand Down

0 comments on commit fb81de3

Please sign in to comment.