diff --git a/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketHandler.java b/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketHandler.java index 6234e95fb3c..f664a37aeef 100644 --- a/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketHandler.java +++ b/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketHandler.java @@ -17,6 +17,7 @@ import io.micronaut.context.event.ApplicationEventPublisher; import io.micronaut.core.annotation.Internal; +import io.micronaut.core.annotation.NonNull; import io.micronaut.core.annotation.Nullable; import io.micronaut.core.async.publisher.Publishers; import io.micronaut.core.bind.BoundExecutable; @@ -26,6 +27,7 @@ import io.micronaut.core.propagation.PropagatedContext; import io.micronaut.core.type.Argument; import io.micronaut.core.type.Executable; +import io.micronaut.core.type.ReturnType; import io.micronaut.core.util.KotlinUtils; import io.micronaut.http.HttpAttributes; import io.micronaut.http.HttpRequest; @@ -38,6 +40,8 @@ import io.micronaut.http.server.netty.NettyEmbeddedServices; import io.micronaut.inject.ExecutableMethod; import io.micronaut.inject.MethodExecutionHandle; +import io.micronaut.scheduling.executor.ExecutorSelector; +import io.micronaut.scheduling.executor.ThreadSelection; import io.micronaut.web.router.UriRouteMatch; import io.micronaut.websocket.CloseReason; import io.micronaut.websocket.WebSocketPongMessage; @@ -65,6 +69,8 @@ import java.util.List; import java.util.Optional; import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -91,6 +97,8 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler { private final Argument bodyArgument; private final Argument pongArgument; + private final ThreadSelection threadSelection; + private final ExecutorSelector executorSelector; /** * Default constructor. @@ -102,17 +110,20 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler { * @param request The request used to create the websocket * @param routeMatch The route match * @param ctx The channel handler context + * @param executorSelector * @param coroutineHelper Helper for kotlin coroutines */ NettyServerWebSocketHandler( - NettyEmbeddedServices nettyEmbeddedServices, - WebSocketSessionRepository webSocketSessionRepository, - WebSocketServerHandshaker handshaker, - WebSocketBean webSocketBean, - HttpRequest request, - UriRouteMatch routeMatch, - ChannelHandlerContext ctx, - @Nullable CoroutineHelper coroutineHelper) { + NettyEmbeddedServices nettyEmbeddedServices, + WebSocketSessionRepository webSocketSessionRepository, + WebSocketServerHandshaker handshaker, + WebSocketBean webSocketBean, + HttpRequest request, + UriRouteMatch routeMatch, + ChannelHandlerContext ctx, + ThreadSelection threadSelection, + ExecutorSelector executorSelector, + @Nullable CoroutineHelper coroutineHelper) { super( ctx, nettyEmbeddedServices.getRequestArgumentSatisfier().getBinderRegistry(), @@ -125,6 +136,9 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler { webSocketSessionRepository, nettyEmbeddedServices.getApplicationContext().getConversionService()); + this.threadSelection = threadSelection; + this.executorSelector = executorSelector; + this.serverSession = createWebSocketSession(ctx); ExecutableBinder binder = new DefaultExecutableBinder<>(); @@ -345,8 +359,25 @@ protected Object invokeExecutable(BoundExecutable boundExecutable, MethodExecuti } private Object invokeExecutable0(BoundExecutable boundExecutable, MethodExecutionHandle messageHandler) { - return ServerRequestContext.with(originatingRequest, - (Supplier) () -> boundExecutable.invoke(messageHandler.getTarget())); + return this.executorSelector.select(messageHandler.getExecutableMethod(), threadSelection) + .map( + executorService -> { + ReturnType returnType = messageHandler.getExecutableMethod().getReturnType(); + if (returnType.isReactive()) { + return Mono.from((Publisher) boundExecutable.invoke(messageHandler.getTarget())) + .subscribeOn(Schedulers.fromExecutor(executorService)) + .contextWrite(reactorContext -> reactorContext.put(ServerRequestContext.KEY, originatingRequest)); + } else { + return executorService.submit(() -> ServerRequestContext.with(originatingRequest, + (Supplier) () -> boundExecutable.invoke(messageHandler.getTarget()))); + } + } + ).orElseGet(invokeWithContext(boundExecutable, messageHandler)); + } + + private Supplier invokeWithContext(BoundExecutable boundExecutable, MethodExecutionHandle messageHandler) { + return () -> ServerRequestContext.with(originatingRequest, + (Supplier) () -> boundExecutable.invoke(messageHandler.getTarget())); } @Override diff --git a/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketUpgradeHandler.java b/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketUpgradeHandler.java index 977bc406ace..e488cf62947 100644 --- a/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketUpgradeHandler.java +++ b/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketUpgradeHandler.java @@ -199,6 +199,8 @@ private void writeResponse(ChannelHandlerContext ctx, msg, routeMatch, ctx, + serverConfiguration.getThreadSelection(), + routeExecutor.getExecutorSelector(), routeExecutor.getCoroutineHelper().orElse(null)); pipeline.addBefore(ctx.name(), NettyServerWebSocketHandler.ID, webSocketHandler); diff --git a/http-server-netty/src/test/groovy/io/micronaut/websocket/WebsocketExecuteOnSpec.groovy b/http-server-netty/src/test/groovy/io/micronaut/websocket/WebsocketExecuteOnSpec.groovy new file mode 100644 index 00000000000..ead8c8f04f3 --- /dev/null +++ b/http-server-netty/src/test/groovy/io/micronaut/websocket/WebsocketExecuteOnSpec.groovy @@ -0,0 +1,214 @@ +package io.micronaut.websocket + +import io.micronaut.context.annotation.Property +import io.micronaut.context.annotation.Requires +import io.micronaut.runtime.server.EmbeddedServer +import io.micronaut.scheduling.LoomSupport +import io.micronaut.scheduling.TaskExecutors +import io.micronaut.scheduling.annotation.ExecuteOn +import io.micronaut.test.extensions.spock.annotation.MicronautTest +import io.micronaut.websocket.annotation.* +import jakarta.inject.Inject +import org.reactivestreams.Publisher +import org.slf4j.Logger +import org.slf4j.LoggerFactory +import reactor.core.publisher.Flux +import reactor.core.publisher.Mono +import spock.lang.Specification +import spock.lang.Unroll +import spock.util.concurrent.PollingConditions + +import java.util.concurrent.Future +import java.util.function.Predicate +import java.util.function.Supplier +import java.util.stream.Collectors + +@Property(name = "spec.name", value = "WebsocketExecuteOnSpec") +@MicronautTest +class WebsocketExecuteOnSpec extends Specification { + + static final Logger LOG = LoggerFactory.getLogger(WebsocketExecuteOnSpec.class) + + @Inject + EmbeddedServer embeddedServer + + @Unroll + void "#type websocket server methods can run outside of the event loop with ExecuteOn"() { + given: + WebSocketClient wsClient = embeddedServer.applicationContext.createBean(WebSocketClient.class, embeddedServer.getURL()) + String threadName = (LoomSupport.isSupported() ? "virtual" : TaskExecutors.IO) + "-executor" + String expectedJoined = "joined on thread " + threadName + String expectedEcho = "Hello from thread " + threadName + + expect: + wsClient + + when: + EchoClientWebSocket echoClientWebSocket = Flux.from(wsClient.connect(EchoClientWebSocket, "/echo/${type}")).blockFirst() + + then: + noExceptionThrown() + new PollingConditions().eventually { + echoClientWebSocket.receivedMessages() == [expectedJoined] + } + + when: + echoClientWebSocket.send('Hello') + + then: + new PollingConditions().eventually { + echoClientWebSocket.receivedMessages() == [expectedJoined, expectedEcho] + } + + cleanup: + echoClientWebSocket.close() + + where: + type | _ + "sync" | _ + "reactive" | _ + "async" | _ + } + + @Requires(property = "spec.name", value = "WebsocketExecuteOnSpec") + @ServerWebSocket("/echo/sync") + @ExecuteOn(TaskExecutors.BLOCKING) + static class SynchronousEchoServerWebSocket { + public static final String JOINED = "joined on thread %s" + public static final String DISCONNECTED = "disconnected on thread %s" + public static final String ECHO = "%s from thread %s" + + @Inject + WebSocketBroadcaster broadcaster + + @OnOpen + void onOpen(WebSocketSession session) { + broadcaster.broadcastSync(JOINED.formatted(Thread.currentThread().getName()), isValid(session)) + } + + @OnMessage + void onMessage(String message, WebSocketSession session) { + broadcaster.broadcastSync(ECHO.formatted(message, Thread.currentThread().getName()), isValid(session)) + } + + @OnClose + void onClose(WebSocketSession session) { + broadcaster.broadcastSync(DISCONNECTED.formatted(Thread.currentThread().getName()), isValid(session)) + } + + private static Predicate isValid(WebSocketSession session) { + return { s -> s == session } + } + } + + @Requires(property = "spec.name", value = "WebsocketExecuteOnSpec") + @ServerWebSocket("/echo/reactive") + @ExecuteOn(TaskExecutors.BLOCKING) + static class ReactiveEchoServerWebSocket { + public static final String JOINED = "joined on thread %s" + public static final String DISCONNECTED = "disconnected on thread %s" + public static final String ECHO = " from thread %s" + + @Inject + WebSocketBroadcaster broadcaster + + Supplier formatMessage(String message) { + () -> message.formatted(Thread.currentThread().getName()) + } + + @OnOpen + Publisher onOpen(WebSocketSession session) { + Mono.fromSupplier(formatMessage(JOINED)) + .flatMap(message -> Mono.from(broadcaster.broadcast(message))) + } + + @OnMessage + Publisher onMessage(String message, WebSocketSession session) { + Mono.fromSupplier(formatMessage(message + ECHO)) + .flatMap(m -> Mono.from(broadcaster.broadcast(m))) + } + + @OnClose + Publisher onClose(WebSocketSession session) { + Mono.just(session) + .flatMap(s -> { + LOG.info(DISCONNECTED.formatted(Thread.currentThread().getName())) + return Mono.just("closed") + }) + } + } + + @Requires(property = "spec.name", value = "WebsocketExecuteOnSpec") + @ServerWebSocket("/echo/async") + @ExecuteOn(TaskExecutors.BLOCKING) + static class AsyncEchoServerWebSocket { + public static final String JOINED = "joined on thread %s" + public static final String DISCONNECTED = "disconnected on thread %s" + public static final String ECHO = " from thread %s" + + @Inject + WebSocketBroadcaster broadcaster + + Supplier formatMessage(String message) { + () -> message.formatted(Thread.currentThread().getName()) + } + + @OnOpen + Future onOpen(WebSocketSession session) { + Mono.fromSupplier(formatMessage(JOINED)) + .flatMap(message -> Mono.from(broadcaster.broadcast(message))).toFuture(); + } + + @OnMessage + Future onMessage(String message, WebSocketSession session) { + Mono.fromSupplier(formatMessage(message + ECHO)) + .flatMap(m -> Mono.from(broadcaster.broadcast(m))).toFuture() + } + + @OnClose + Future onClose(WebSocketSession session) { + Mono.just(session) + .flatMap(s -> { + LOG.info(DISCONNECTED.formatted(Thread.currentThread().getName())) + return Mono.just("closed") + }).toFuture() + } + } + + @Requires(property = "spec.name", value = "WebsocketExecuteOnSpec") + @ClientWebSocket + static abstract class EchoClientWebSocket implements AutoCloseable { + + static final String RECEIVED = "RECEIVED:" + + private WebSocketSession session + private List replies = new ArrayList<>() + + @OnOpen + void onOpen(WebSocketSession session) { + this.session = session + } + List getReplies() { + return replies + } + + @OnMessage + void onMessage(String message) { + replies.add(RECEIVED + message) + } + + abstract void send(String message) + + List receivedMessages() { + return filterMessagesByType(RECEIVED) + } + + List filterMessagesByType(String type) { + replies.stream() + .filter(str -> str.contains(type)) + .map(str -> str.replaceAll(type, "")) + .map(str -> str.substring(0, str.length()-(1)).replace("-thread-", "")) + .collect(Collectors.toList()) + } + } +}