Skip to content

Commit

Permalink
WebSocket Extensions (#1934)
Browse files Browse the repository at this point in the history
* Addresses #1607
  • Loading branch information
dansiviter committed Sep 14, 2020
1 parent b08b862 commit 701a01c
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 8 deletions.
2 changes: 1 addition & 1 deletion dependencies/pom.xml
Expand Up @@ -110,7 +110,7 @@
<version.lib.snakeyaml>1.24</version.lib.snakeyaml>
<version.lib.transaction-api>1.3.3</version.lib.transaction-api>
<version.lib.typesafe-config>1.4.0</version.lib.typesafe-config>
<version.lib.tyrus>1.15</version.lib.tyrus>
<version.lib.tyrus>1.17</version.lib.tyrus>
<version.lib.ucp>${version.lib.ojdbc8}</version.lib.ucp>
<version.lib.validation-api>2.0.2</version.lib.validation-api>
<version.lib.websockets-api>1.1.2</version.lib.websockets-api>
Expand Down
Expand Up @@ -22,6 +22,7 @@
import java.util.logging.Logger;

import javax.websocket.Endpoint;
import javax.websocket.Extension;
import javax.websocket.server.ServerApplicationConfig;

/**
Expand All @@ -32,11 +33,13 @@ public final class WebSocketApplication {
private Class<? extends ServerApplicationConfig> applicationClass;
private Set<Class<?>> annotatedEndpoints;
private Set<Class<? extends Endpoint>> programmaticEndpoints;
private Set<Extension> extensions;

private WebSocketApplication(Builder builder) {
this.applicationClass = builder.applicationClass;
this.annotatedEndpoints = builder.annotatedEndpoints;
this.programmaticEndpoints = builder.programmaticEndpoints;
this.extensions = builder.extensions;
}

/**
Expand Down Expand Up @@ -75,6 +78,15 @@ public Set<Class<?>> annotatedEndpoints() {
return annotatedEndpoints;
}

/**
* Get list of installed extensions.
*
* @return List of installed extensions.
*/
public Set<Extension> extensions() {
return extensions;
}

/**
* Fluent API builder to create {@link WebSocketApplication} instances.
*/
Expand All @@ -84,6 +96,7 @@ public static class Builder {
private Class<? extends ServerApplicationConfig> applicationClass;
private Set<Class<?>> annotatedEndpoints = new HashSet<>();
private Set<Class<? extends Endpoint>> programmaticEndpoints = new HashSet<>();
private Set<Extension> extensions = new HashSet<>();

/**
* Updates an application class in the builder. Clears all results from scanning.
Expand Down Expand Up @@ -135,6 +148,17 @@ public Builder annotatedEndpoint(Class<?> annotatedEndpoint) {
return this;
}

/**
* Add single extension.
*
* @param extension Extension.
* @return The builder.
*/
public Builder extension(Extension extension) {
extensions.add(extension);
return this;
}

/**
* Builds application.
*
Expand Down
Expand Up @@ -100,7 +100,7 @@ private void endpointClasses(@Observes @WithAnnotations(ServerEndpoint.class) Pr
}

/**
* Collects programmatic endpoints .
* Collects programmatic endpoints.
*
* @param endpoint The endpoint.
*/
Expand All @@ -109,6 +109,26 @@ private void endpointConfig(@Observes ProcessAnnotatedType<? extends Endpoint> e
appBuilder.programmaticEndpoint(endpoint.getAnnotatedType().getJavaClass());
}

/**
* Collects extensions.
*
* @param extension The extension.
*/
private void extension(@Observes ProcessAnnotatedType<? extends javax.websocket.Extension> extension) {
LOGGER.finest(() -> "Extension found " + extension.getAnnotatedType().getJavaClass());

Class<? extends javax.websocket.Extension> cls = extension.getAnnotatedType().getJavaClass();
try {
javax.websocket.Extension instance = cls.getConstructor().newInstance();
appBuilder.extension(instance);
} catch (NoSuchMethodException e) {
LOGGER.warning(() -> "Extension does not have no-args constructor for "
+ extension.getAnnotatedType().getJavaClass() + "! Skppping.");
} catch (ReflectiveOperationException e) {
throw new IllegalStateException("Unable to load WebSocket extension", e);
}
}

/**
* Provides access to websocket application.
*
Expand Down Expand Up @@ -165,6 +185,7 @@ private void registerWebSockets() {
// Direct registration without calling application class
app.annotatedEndpoints().forEach(builder::register);
app.programmaticEndpoints().forEach(builder::register);
app.extensions().forEach(builder::register);

// Create routing builder
routing = serverCdiExtension.serverRoutingBuilder();
Expand Down
Expand Up @@ -20,10 +20,13 @@
import javax.websocket.CloseReason;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.Extension;
import javax.websocket.MessageHandler;
import javax.websocket.Session;
import java.io.IOException;
import java.net.URI;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
Expand All @@ -46,14 +49,20 @@ class EchoClient {

private final URI uri;
private final BiFunction<String, String, Boolean> equals;
private final List<Extension> extensions;

public EchoClient(URI uri) {
this(uri, String::equals);
}

public EchoClient(URI uri, BiFunction<String, String, Boolean> equals) {
public EchoClient(URI uri, Extension... extensions) {
this(uri, String::equals, extensions);
}

public EchoClient(URI uri, BiFunction<String, String, Boolean> equals, Extension... extensions) {
this.uri = uri;
this.equals = equals;
this.extensions = Arrays.asList(extensions);
}

/**
Expand All @@ -66,7 +75,7 @@ public void echo(String... messages) throws Exception {
CountDownLatch messageLatch = new CountDownLatch(messages.length);
CompletableFuture<Void> openFuture = new CompletableFuture<>();
CompletableFuture<Void> closeFuture = new CompletableFuture<>();
ClientEndpointConfig config = ClientEndpointConfig.Builder.create().build();
ClientEndpointConfig config = ClientEndpointConfig.Builder.create().extensions(extensions).build();

client.connectToServer(new Endpoint() {
@Override
Expand Down
Expand Up @@ -22,14 +22,12 @@

import javax.enterprise.context.Dependent;
import javax.enterprise.inject.se.SeContainerInitializer;
import javax.enterprise.inject.spi.CDI;
import javax.websocket.Endpoint;
import javax.websocket.server.ServerApplicationConfig;
import javax.websocket.server.ServerEndpointConfig;

import io.helidon.microprofile.server.RoutingPath;

import io.helidon.microprofile.server.ServerCdiExtension;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

Expand Down
@@ -0,0 +1,94 @@
/*
* Copyright (c) 2020 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.helidon.microprofile.tyrus;

import java.net.URI;
import java.util.Collections;
import java.util.List;
import java.util.logging.Logger;

import javax.enterprise.inject.se.SeContainer;
import javax.enterprise.inject.se.SeContainerInitializer;
import javax.enterprise.inject.spi.CDI;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import io.helidon.microprofile.server.ServerCdiExtension;
import javax.websocket.Extension;
import javax.websocket.OnMessage;
import javax.websocket.Session;
import javax.websocket.server.ServerEndpoint;

/**
* A test that mixes Websocket endpoints and extensions in the same application.
*/
public class WebSocketExtensionEndpointTest {

static SeContainer container;

@BeforeAll
static void initClass() {
container = SeContainerInitializer.newInstance()
.addBeanClasses(ExtensionEndpointAnnot.class, TestExtension.class)
.initialize();
}

@AfterAll
static void destroyClass() {
container.close();
}

public int port() {
ServerCdiExtension cdiExtension = CDI.current().getBeanManager().getExtension(ServerCdiExtension.class);
return cdiExtension.port();
}

@Test
public void test() throws Exception {
URI echoUri = URI.create("ws://localhost:" + port() + "/extAnnot");
EchoClient echoClient = new EchoClient(echoUri, new TestExtension());
echoClient.echo("hi", "how are you?");
}

@ServerEndpoint("/extAnnot")
public static class ExtensionEndpointAnnot {
private static final Logger LOGGER = Logger.getLogger(ExtensionEndpointAnnot.class.getName());

@OnMessage
public void echo(Session session, String message) throws Exception {
LOGGER.info("OnMessage called '" + message + "'");
if (session.getNegotiatedExtensions().isEmpty()) {
throw new IllegalStateException();
}
session.getBasicRemote().sendObject(message);
}
}

public static class TestExtension implements Extension {
@Override
public String getName() {
return "testExtension";
}

@Override
public List<Parameter> getParameters() {
return Collections.emptyList();
}
}
}
Expand Up @@ -27,6 +27,7 @@
import java.util.logging.Logger;

import javax.websocket.DeploymentException;
import javax.websocket.Extension;
import javax.websocket.server.HandshakeRequest;
import javax.websocket.server.ServerEndpointConfig;

Expand Down Expand Up @@ -60,6 +61,7 @@ public class TyrusSupport implements Service {
private final TyrusHandler handler = new TyrusHandler();
private Set<Class<?>> endpointClasses;
private Set<ServerEndpointConfig> endpointConfigs;
private Set<Extension> extensions;

/**
* Create from another instance.
Expand All @@ -70,12 +72,18 @@ protected TyrusSupport(TyrusSupport other) {
this.engine = other.engine;
this.endpointClasses = other.endpointClasses;
this.endpointConfigs = other.endpointConfigs;
this.extensions = other.extensions;
}

TyrusSupport(WebSocketEngine engine, Set<Class<?>> endpointClasses, Set<ServerEndpointConfig> endpointConfigs) {
TyrusSupport(
WebSocketEngine engine,
Set<Class<?>> endpointClasses,
Set<ServerEndpointConfig> endpointConfigs,
Set<Extension> extensions) {
this.engine = engine;
this.endpointClasses = endpointClasses;
this.endpointConfigs = endpointConfigs;
this.extensions = extensions;
}

/**
Expand Down Expand Up @@ -108,6 +116,15 @@ public Set<ServerEndpointConfig> endpointConfigs() {
return Collections.unmodifiableSet(endpointConfigs);
}

/**
* Access to extensions.
*
* @return Immutable set of extensions.
*/
public Set<Extension> extensions() {
return Collections.unmodifiableSet(extensions);
}

/**
* Returns executor service, can be overridden.
*
Expand All @@ -133,6 +150,7 @@ public static class Builder implements io.helidon.common.Builder<TyrusSupport> {

private Set<Class<?>> endpointClasses = new HashSet<>();
private Set<ServerEndpointConfig> endpointConfigs = new HashSet<>();
private Set<Extension> extensions = new HashSet<>();

private Builder() {
}
Expand All @@ -159,8 +177,21 @@ public Builder register(ServerEndpointConfig endpointConfig) {
return this;
}

/**
* Register an extension.
*
* @param extension The extension.
* @return The builder.
*/
public Builder register(Extension extension) {
extensions.add(extension);
return this;
}

@Override
public TyrusSupport build() {
// a purposefully mutable extensions
Set<Extension> installedExtensions = new HashSet<>(extensions);
// Create container and WebSocket engine
TyrusServerContainer serverContainer = new TyrusServerContainer(endpointClasses) {
private final WebSocketEngine engine =
Expand All @@ -176,6 +207,11 @@ public void register(ServerEndpointConfig serverEndpointConfig) {
throw new UnsupportedOperationException("Use TyrusWebSocketEngine for registration");
}

@Override
public Set<Extension> getInstalledExtensions() {
return installedExtensions;
}

@Override
public WebSocketEngine getWebSocketEngine() {
return engine;
Expand All @@ -202,7 +238,7 @@ public WebSocketEngine getWebSocketEngine() {
});

// Create TyrusSupport using WebSocket engine
return new TyrusSupport(serverContainer.getWebSocketEngine(), endpointClasses, endpointConfigs);
return new TyrusSupport(serverContainer.getWebSocketEngine(), endpointClasses, endpointConfigs, extensions);
}
}

Expand Down

0 comments on commit 701a01c

Please sign in to comment.