diff --git a/src/main/java/org/jboss/remoting3/Connection.java b/src/main/java/org/jboss/remoting3/Connection.java index 2dbe9f2e8..e9508c844 100644 --- a/src/main/java/org/jboss/remoting3/Connection.java +++ b/src/main/java/org/jboss/remoting3/Connection.java @@ -121,6 +121,13 @@ default S getPeerAddress(Class type) { */ URI getPeerURI(); + /** + * Get the protocol of this connection. + * + * @return the protocol (not {@code null}) + */ + String getProtocol(); + /** * Get the local identity of this inbound connection. * diff --git a/src/main/java/org/jboss/remoting3/ConnectionImpl.java b/src/main/java/org/jboss/remoting3/ConnectionImpl.java index 5e481582a..5227e1ce2 100644 --- a/src/main/java/org/jboss/remoting3/ConnectionImpl.java +++ b/src/main/java/org/jboss/remoting3/ConnectionImpl.java @@ -63,11 +63,13 @@ class ConnectionImpl extends AbstractHandleableCloseable implements private final IntIndexHashMap authMap = new IntIndexHashMap(Auth::getId); private final SaslAuthenticationFactory authenticationFactory; private final AuthenticationConfiguration authenticationConfiguration; + private final String protocol; ConnectionImpl(final EndpointImpl endpoint, final ConnectionHandlerFactory connectionHandlerFactory, final ConnectionProviderContext connectionProviderContext, final URI peerUri, final Principal principal, final UnaryOperator saslClientFactoryOperator, final SaslAuthenticationFactory authenticationFactory, final AuthenticationConfiguration authenticationConfiguration) { super(endpoint.getExecutor(), true); this.endpoint = endpoint; this.peerUri = peerUri; + this.protocol = connectionProviderContext.getProtocol(); this.principal = principal; this.authenticationConfiguration = authenticationConfiguration; this.connectionHandler = connectionHandlerFactory.createInstance(endpoint.new LocalConnectionContext(connectionProviderContext, this)); @@ -124,6 +126,10 @@ public URI getPeerURI() { return peerUri; } + public String getProtocol() { + return protocol; + } + public SecurityIdentity getLocalIdentity() { return connectionHandler.getLocalIdentity(); } diff --git a/src/main/java/org/jboss/remoting3/EndpointImpl.java b/src/main/java/org/jboss/remoting3/EndpointImpl.java index 6bdf1f3ba..47eae7a2c 100644 --- a/src/main/java/org/jboss/remoting3/EndpointImpl.java +++ b/src/main/java/org/jboss/remoting3/EndpointImpl.java @@ -23,6 +23,7 @@ package org.jboss.remoting3; import static java.security.AccessController.doPrivileged; +import static org.xnio.IoUtils.safeClose; import java.io.IOException; import java.net.InetSocketAddress; @@ -73,7 +74,6 @@ import org.wildfly.security.auth.server.SaslAuthenticationFactory; import org.wildfly.security.sasl.util.PrivilegedSaslClientFactory; import org.wildfly.security.sasl.util.ProtocolSaslClientFactory; -import org.wildfly.security.sasl.util.SaslFactories; import org.wildfly.security.sasl.util.ServerNameSaslClientFactory; import org.xnio.Bits; @@ -106,7 +106,7 @@ final class EndpointImpl extends AbstractHandleableCloseable implement private final Attachments attachments = new Attachments(); - private final ConcurrentMap connectionProviders = new ConcurrentHashMap<>(); + private final ConcurrentMap connectionProviders = new ConcurrentHashMap<>(); private final ConcurrentMap registeredServices = new ConcurrentHashMap<>(); private final ConcurrentMap configuredConnections = new ConcurrentHashMap<>(); @@ -128,7 +128,6 @@ final class EndpointImpl extends AbstractHandleableCloseable implement * The name of this endpoint. */ private final String name; - private final ConnectionProviderContext connectionProviderContext; private final CloseHandler resourceCloseHandler = (closed, exception) -> closeTick1(closed); private final CloseHandler connectionCloseHandler = (closed, exception) -> connections.remove(closed); private final boolean ourWorker; @@ -141,8 +140,6 @@ private EndpointImpl(final XnioWorker xnioWorker, final boolean ourWorker, final this.xnio = xnioWorker.getXnio(); this.name = name; this.defaultBindAddress = defaultBindAddress; - // initialize CPC - connectionProviderContext = new ConnectionProviderContextImpl(); // get XNIO worker log.tracef("Completed open of %s", this); } @@ -358,8 +355,8 @@ protected void closeAction() throws IOException { for (Object connection : connections.toArray()) { ((ConnectionImpl)connection).closeAsync(); } - for (ConnectionProvider connectionProvider : connectionProviders.values()) { - connectionProvider.closeAsync(); + for (ProtocolRegistration protocolRegistration : connectionProviders.values()) { + protocolRegistration.getProvider().closeAsync(); } } } @@ -457,10 +454,11 @@ IoFuture connect(final URI destination, final SocketAddress bindAddr boolean ok = false; resourceUntick("Connection to " + destination); try { - final ConnectionProvider connectionProvider = connectionProviders.get(scheme); - if (connectionProvider == null) { + final ProtocolRegistration protocolRegistration = connectionProviders.get(scheme); + if (protocolRegistration == null) { throw new UnknownURISchemeException("No connection provider for URI scheme \"" + scheme + "\" is installed"); } + final ConnectionProvider connectionProvider = protocolRegistration.getProvider(); final FutureResult futureResult = new FutureResult(getExecutor()); // Mark the stack because otherwise debugging connect problems can be incredibly tough final StackTraceElement[] mark = Thread.currentThread().getStackTrace(); @@ -495,7 +493,7 @@ public boolean setResult(final ConnectionHandlerFactory connHandlerFactory) { } synchronized (connectionLock) { log.logf(getClass().getName(), Logger.Level.TRACE, null, "Registered successful result %s", connHandlerFactory); - final ConnectionImpl connection = new ConnectionImpl(EndpointImpl.this, connHandlerFactory, connectionProviderContext, destination, principal, finalFactoryOperator, null, configuration); + final ConnectionImpl connection = new ConnectionImpl(EndpointImpl.this, connHandlerFactory, protocolRegistration.getContext(), destination, principal, finalFactoryOperator, null, configuration); connections.add(connection); connection.getConnectionHandler().addCloseHandler(SpiUtils.asyncClosingCloseHandler(connection)); connection.addCloseHandler(resourceCloseHandler); @@ -535,15 +533,17 @@ public Registration addConnectionProvider(final String uriScheme, final Connecti boolean ok = false; resourceUntick("Connection provider for " + uriScheme); try { - final ConnectionProviderContextImpl context = new ConnectionProviderContextImpl(); + final ConnectionProviderContextImpl context = new ConnectionProviderContextImpl(uriScheme); final ConnectionProvider provider = providerFactory.createInstance(context, optionMap); + final ProtocolRegistration protocolRegistration = new ProtocolRegistration(provider, context); try { - if (connectionProviders.putIfAbsent(uriScheme, provider) != null) { + if (connectionProviders.putIfAbsent(uriScheme, protocolRegistration) != null) { + safeClose(provider); throw new DuplicateRegistrationException("URI scheme '" + uriScheme + "' is already registered to a provider"); } // add a resource count for close log.tracef("Adding connection provider registration named '%s': %s", uriScheme, provider); - final Registration registration = new MapRegistration(connectionProviders, uriScheme, provider) { + final Registration registration = new MapRegistration(connectionProviders, uriScheme, protocolRegistration) { protected void closeAction() throws IOException { try { provider.closeAsync(); @@ -570,6 +570,24 @@ protected void closeAction() throws IOException { } } + static final class ProtocolRegistration { + private final ConnectionProvider provider; + private final ConnectionProviderContextImpl context; + + ProtocolRegistration(final ConnectionProvider provider, final ConnectionProviderContextImpl context) { + this.provider = provider; + this.context = context; + } + + ConnectionProvider getProvider() { + return provider; + } + + ConnectionProviderContextImpl getContext() { + return context; + } + } + public T getConnectionProviderInterface(final String uriScheme, final Class expectedType) throws UnknownURISchemeException, ClassCastException { final SecurityManager sm = System.getSecurityManager(); if (sm != null) { @@ -578,11 +596,11 @@ public T getConnectionProviderInterface(final String uriScheme, final Class< if (! expectedType.isInterface()) { throw new IllegalArgumentException("Interface expected"); } - final ConnectionProvider provider = connectionProviders.get(uriScheme); - if (provider == null) { + final ProtocolRegistration protocolRegistration = connectionProviders.get(uriScheme); + if (protocolRegistration == null) { throw new UnknownURISchemeException("No connection provider for URI scheme \"" + uriScheme + "\" is installed"); } - return expectedType.cast(provider.getProviderInterface()); + return expectedType.cast(protocolRegistration.getProvider().getProviderInterface()); } public boolean isValidUriScheme(final String uriScheme) { @@ -605,13 +623,13 @@ public String toString() { return b.toString(); } - private class MapRegistration extends AbstractHandleableCloseable implements Registration { + class MapRegistration extends AbstractHandleableCloseable implements Registration { private final ConcurrentMap map; private final String key; private final T value; - private MapRegistration(final ConcurrentMap map, final String key, final T value) { + MapRegistration(final ConcurrentMap map, final String key, final T value) { super(EndpointImpl.this.getExecutor(), false); this.map = map; this.key = key; @@ -631,6 +649,10 @@ public void close() { } } + T getValue() { + return value; + } + public String toString() { return String.format("Registration of '%s': %s", key, value); } @@ -711,9 +733,12 @@ public void receiveAuthDeleteAck(final int id) { } } - private final class ConnectionProviderContextImpl implements ConnectionProviderContext { + final class ConnectionProviderContextImpl implements ConnectionProviderContext { - private ConnectionProviderContextImpl() { + private final String protocol; + + ConnectionProviderContextImpl(final String protocol) { + this.protocol = protocol; } public void accept(final ConnectionHandlerFactory connectionHandlerFactory, final SaslAuthenticationFactory authenticationFactory) { @@ -743,7 +768,7 @@ public Endpoint getEndpoint() { } public Xnio getXnio() { - return xnio; + return worker.getXnio(); } public Executor getExecutor() { @@ -753,9 +778,13 @@ public Executor getExecutor() { public XnioWorker getXnioWorker() { return worker; } + + public String getProtocol() { + return protocol; + } } - private static class RegisteredServiceImpl implements RegisteredService { + static class RegisteredServiceImpl implements RegisteredService { private final OpenListener openListener; private final OptionMap optionMap; diff --git a/src/main/java/org/jboss/remoting3/spi/ConnectionProviderContext.java b/src/main/java/org/jboss/remoting3/spi/ConnectionProviderContext.java index dfeffc907..ffee95adb 100644 --- a/src/main/java/org/jboss/remoting3/spi/ConnectionProviderContext.java +++ b/src/main/java/org/jboss/remoting3/spi/ConnectionProviderContext.java @@ -69,4 +69,11 @@ public interface ConnectionProviderContext { * @return the XNIO worker */ XnioWorker getXnioWorker(); + + /** + * Get the protocol of this connection provider. + * + * @return the protocol of this connection provider (not {@code null}) + */ + String getProtocol(); }