From 8f9e1591852e07bd3875690a17c5349bbe5d71b5 Mon Sep 17 00:00:00 2001 From: Mark Galea Date: Tue, 4 Mar 2014 02:39:03 +0100 Subject: [PATCH] Add reply mode onSession to @SendToUser Added the ability to target a particular user session when a message passes through the broker. Given a user has two tabs open and the client sends a message to the server from tab 1, it is now possible to reply only to tab 1 instead of the default reply to all semantics. Issue: SPR-11506 --- .../simp/SimpMessageHeaderAccessor.java | 12 +++++ .../messaging/simp/annotation/SendToUser.java | 6 +++ .../SendToMethodReturnValueHandler.java | 34 ++++++++++++-- .../broker/SimpleBrokerMessageHandler.java | 27 +++++++++++- .../SendToMethodReturnValueHandlerTests.java | 44 +++++++++++++++++++ .../SimpleBrokerMessageHandlerTests.java | 31 +++++++++++++ 6 files changed, 149 insertions(+), 5 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java index 69d26190b2f0..aa9a241aec80 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java @@ -46,9 +46,12 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { public static final String SESSION_ID_HEADER = "simpSessionId"; + public static final String SESSION_ID_FILTERED = "simpSessionFilter"; + public static final String SUBSCRIPTION_ID_HEADER = "simpSubscriptionId"; public static final String USER_HEADER = "simpUser"; + /** @@ -127,6 +130,15 @@ public void setSessionId(String sessionId) { setHeader(SESSION_ID_HEADER, sessionId); } + public Boolean isSessionFiltered() { + return (Boolean) getHeader(SESSION_ID_FILTERED); + } + + public void setSessionIdFiltered(Boolean sessionIdFiltered) { + setHeader(SESSION_ID_FILTERED, sessionIdFiltered); + } + + public Principal getUser() { return (Principal) getHeader(USER_HEADER); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/SendToUser.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/SendToUser.java index d083baabf996..426fa82c6dd2 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/SendToUser.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/SendToUser.java @@ -48,4 +48,10 @@ */ String[] value() default {}; + /** + * A flag indicating whether the message is to be sent to a particular user session. + * + */ + boolean onSession() default false; + } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java index fe9246b4b2d7..60e52d85c264 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java @@ -28,6 +28,7 @@ import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageSendingOperations; +import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.simp.annotation.SendToUser; import org.springframework.messaging.simp.user.DestinationUserNameProvider; import org.springframework.messaging.support.MessageBuilder; @@ -120,9 +121,9 @@ public void handleReturnValue(Object returnValue, MethodParameter returnType, Me } SimpMessageHeaderAccessor inputHeaders = SimpMessageHeaderAccessor.wrap(inputMessage); - String sessionId = inputHeaders.getSessionId(); MessagePostProcessor postProcessor = new SessionHeaderPostProcessor(sessionId); + SendToUserHeaderPostProcessor sendToUserHeaderPostProcessor = new SendToUserHeaderPostProcessor(sessionId, inputHeaders.getSubscriptionId()); SendToUser sendToUser = returnType.getMethodAnnotation(SendToUser.class); if (sendToUser != null) { @@ -136,9 +137,11 @@ public void handleReturnValue(Object returnValue, MethodParameter returnType, Me } String[] destinations = getTargetDestinations(sendToUser, inputHeaders, this.defaultUserDestinationPrefix); for (String destination : destinations) { - this.messagingTemplate.convertAndSendToUser(userName, destination, returnValue, postProcessor); - } - return; + if(sendToUser.onSession()) + this.messagingTemplate.convertAndSendToUser(userName, destination, returnValue, sendToUserHeaderPostProcessor); + else + this.messagingTemplate.convertAndSendToUser(userName, destination, returnValue, postProcessor); + } } else { SendTo sendTo = returnType.getMethodAnnotation(SendTo.class); @@ -178,6 +181,29 @@ public Message postProcessMessage(Message message) { } } + private final class SendToUserHeaderPostProcessor implements MessagePostProcessor { + + private final String sessionId; + + private final String subscriptionId; + + + public SendToUserHeaderPostProcessor(String sessionId, String subscriptionId) { + this.sessionId = sessionId; + this.subscriptionId = subscriptionId; + } + + @Override + public Message postProcessMessage(Message message) { + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); + headers.setSessionId(this.sessionId); + headers.setSubscriptionId(this.subscriptionId); + headers.setSessionIdFiltered(true); + headers.setMessageTypeIfNotSet(SimpMessageType.MESSAGE); + return MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); + } + } + @Override public String toString() { return "SendToMethodReturnValueHandler [annotationRequired=" + annotationRequired + "]"; diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java index d0cc6312d697..68ead976516e 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java @@ -129,7 +129,32 @@ else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) { this.subscriptionRegistry.unregisterSubscription(message); } else if (SimpMessageType.MESSAGE.equals(messageType)) { - sendMessageToSubscribers(headers.getDestination(), message); + if (headers.isSessionFiltered() != null && headers.isSessionFiltered()) { + // If this is a message then we would not know the subscription Identifier. + if (headers.getSubscriptionId() == null) { + MultiValueMap subscriptions = this.subscriptionRegistry.findSubscriptions(message); + if (headers.isSessionFiltered() != null && headers.isSessionFiltered()) { + if (subscriptions.get(headers.getSessionId()) != null ){ + for (String subscriptionId : subscriptions.get(headers.getSessionId())) { + headers.setSubscriptionId(subscriptionId); + Object payload = message.getPayload(); + Message clientMessage = MessageBuilder.withPayload(payload).setHeaders(headers).build(); + try { + this.clientOutboundChannel.send(clientMessage); + } + catch (Throwable ex) { + logger.error("Failed to send message to destination=" + destination + + ", sessionId=" + headers.getSessionId() + ", subscriptionId=" + subscriptionId, ex); + } + } + } + } + } else { + this.clientOutboundChannel.send(message); + } + } + else + sendMessageToSubscribers(headers.getDestination(), message); } else if (SimpMessageType.DISCONNECT.equals(messageType)) { String sessionId = headers.getSessionId(); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java index 84f71e14cd39..a762a7b718f6 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java @@ -69,6 +69,7 @@ public class SendToMethodReturnValueHandlerTests { private MethodParameter sendToDefaultDestReturnType; private MethodParameter sendToUserReturnType; private MethodParameter sendToUserDefaultDestReturnType; + private MethodParameter sendToUserOnSessionReturnType; @Before @@ -98,6 +99,10 @@ public void setup() throws Exception { method = this.getClass().getDeclaredMethod("handleAndSendToUser"); this.sendToUserReturnType = new MethodParameter(method, -1); + method = this.getClass().getDeclaredMethod("handleAndSendToUserOnSession"); + this.sendToUserOnSessionReturnType = new MethodParameter(method, -1); + + method = this.getClass().getDeclaredMethod("handleAndSendToUserDefaultDestination"); this.sendToUserDefaultDestReturnType = new MethodParameter(method, -1); } @@ -125,6 +130,7 @@ public void sendToNoAnnotations() throws Exception { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); assertEquals("sess1", headers.getSessionId()); assertNull(headers.getSubscriptionId()); + assertNull(headers.isSessionFiltered()); assertEquals("/topic/dest", headers.getDestination()); } @@ -149,6 +155,7 @@ public void sendTo() throws Exception { headers = SimpMessageHeaderAccessor.wrap(message); assertEquals(sessionId, headers.getSessionId()); assertNull(headers.getSubscriptionId()); + assertNull(headers.isSessionFiltered()); assertEquals("/dest2", headers.getDestination()); } @@ -167,6 +174,7 @@ public void sendToDefaultDestination() throws Exception { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); assertEquals(sessionId, headers.getSessionId()); assertNull(headers.getSubscriptionId()); + assertNull(headers.isSessionFiltered()); assertEquals("/topic/dest", headers.getDestination()); } @@ -192,6 +200,7 @@ public void sendToUser() throws Exception { headers = SimpMessageHeaderAccessor.wrap(message); assertEquals(sessionId, headers.getSessionId()); assertNull(headers.getSubscriptionId()); + assertNull(headers.isSessionFiltered()); assertEquals("/user/" + user.getName() + "/dest2", headers.getDestination()); } @@ -208,9 +217,12 @@ public void sendToUserWithUserNameProvider() throws Exception { verify(this.messageChannel, times(2)).send(this.messageCaptor.capture()); SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(this.messageCaptor.getAllValues().get(0)); + assertNull(headers.getSubscriptionId()); assertEquals("/user/Me myself and I/dest1", headers.getDestination()); headers = SimpMessageHeaderAccessor.wrap(this.messageCaptor.getAllValues().get(1)); + assertNull(headers.getSubscriptionId()); + assertNull(headers.isSessionFiltered()); assertEquals("/user/Me myself and I/dest2", headers.getDestination()); } @@ -230,9 +242,36 @@ public void sendToUserDefaultDestination() throws Exception { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); assertEquals(sessionId, headers.getSessionId()); assertNull(headers.getSubscriptionId()); + assertNull(headers.isSessionFiltered()); assertEquals("/user/" + user.getName() + "/queue/dest", headers.getDestination()); } + + @Test + public void sendToUserParticularSession() throws Exception { + + when(this.messageChannel.send(any(Message.class))).thenReturn(true); + + String sessionId = "sess1"; + String subscriptionId = "subs1"; + String destination = "/dest"; + TestUser user = new TestUser(); + Message inputMessage = createInputMessage(sessionId, subscriptionId, destination, user); + + this.handler.handleReturnValue(payloadContent, this.sendToUserOnSessionReturnType, inputMessage); + + verify(this.messageChannel).send(this.messageCaptor.capture()); + assertNotNull(this.messageCaptor.getValue()); + + verify(this.messageChannel, times(1)).send(this.messageCaptor.capture()); + + Message message = this.messageCaptor.getAllValues().get(0); + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message); + assertEquals("sessionId should always be copied", sessionId, headers.getSessionId()); + assertEquals(subscriptionId, headers.getSubscriptionId()); + assertTrue(headers.isSessionFiltered()); + assertEquals("/user/" + user.getName() + "/dest1", headers.getDestination()); + } private Message createInputMessage(String sessId, String subsId, String destination, Principal principal) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(); @@ -285,6 +324,11 @@ public String handleAndSendToUserDefaultDestination() { return payloadContent; } + @SendToUser(value={"/dest1"}, onSession=true) + public String handleAndSendToUserOnSession() { + return payloadContent; + } + @SendToUser({"/dest1", "/dest2"}) public String handleAndSendToUser() { return payloadContent; diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java index 7b15b2f49b9f..35b8fd6e6253 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java @@ -89,6 +89,27 @@ public void subcribePublish() { assertCapturedMessage("sess1", "sub3", "/bar"); assertCapturedMessage("sess2", "sub3", "/bar"); } + + @Test + public void subcribePublishOnSession() { + + this.messageHandler.start(); + + this.messageHandler.handleMessage(createSubscriptionMessage("sess1", "sub1", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage("sess1", "sub2", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage("sess1", "sub3", "/bar")); + + this.messageHandler.handleMessage(createSubscriptionMessage("sess2", "sub1", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage("sess2", "sub2", "/foo")); + this.messageHandler.handleMessage(createSubscriptionMessage("sess2", "sub3", "/bar")); + + this.messageHandler.handleMessage(createMessageWithSubscriptionIdAndSessionId("/foo", "message1", "sess1", "sub1")); + this.messageHandler.handleMessage(createMessageWithSubscriptionIdAndSessionId("/bar", "message1", "sess2", "sub3")); + + verify(this.clientOutboundChannel, times(2)).send(this.messageCaptor.capture()); + assertCapturedMessage("sess1", "sub1", "/foo"); + assertCapturedMessage("sess2", "sub3", "/bar"); + } @Test public void subcribeDisconnectPublish() { @@ -164,6 +185,16 @@ protected Message createMessage(String destination, String payload) { return MessageBuilder.withPayload(payload).copyHeaders(headers.toMap()).build(); } + protected Message createMessageWithSubscriptionIdAndSessionId(String destination, String payload, String sessionId, String subscriptionId) { + + SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); + headers.setDestination(destination); + headers.setSessionId(sessionId); + headers.setSubscriptionId(subscriptionId); + headers.setSessionIdFiltered(Boolean.TRUE); + return MessageBuilder.withPayload(payload).copyHeaders(headers.toMap()).build(); + } + protected boolean assertCapturedMessage(String sessionId, String subcriptionId, String destination) { for (Message message : this.messageCaptor.getAllValues()) { SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);