Skip to content

Commit

Permalink
Add reply mode onSession to @SendToUser
Browse files Browse the repository at this point in the history
Added the ability to target a particular user session when a message
passes through the broker.  Given a user has two tabs open and the client
sends a message to the server from tab 1, it is now possible to reply only to tab 1
instead of the default reply to all semantics.

Issue: SPR-11506
  • Loading branch information
cloudmark committed Mar 4, 2014
1 parent 035d9d5 commit 8f9e159
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 5 deletions.
Expand Up @@ -46,9 +46,12 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {

public static final String SESSION_ID_HEADER = "simpSessionId";

public static final String SESSION_ID_FILTERED = "simpSessionFilter";

public static final String SUBSCRIPTION_ID_HEADER = "simpSubscriptionId";

public static final String USER_HEADER = "simpUser";



/**
Expand Down Expand Up @@ -127,6 +130,15 @@ public void setSessionId(String sessionId) {
setHeader(SESSION_ID_HEADER, sessionId);
}

public Boolean isSessionFiltered() {
return (Boolean) getHeader(SESSION_ID_FILTERED);
}

public void setSessionIdFiltered(Boolean sessionIdFiltered) {
setHeader(SESSION_ID_FILTERED, sessionIdFiltered);
}


public Principal getUser() {
return (Principal) getHeader(USER_HEADER);
}
Expand Down
Expand Up @@ -48,4 +48,10 @@
*/
String[] value() default {};

/**
* A flag indicating whether the message is to be sent to a particular user session.
*
*/
boolean onSession() default false;

}
Expand Up @@ -28,6 +28,7 @@
import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageSendingOperations;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.annotation.SendToUser;
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.support.MessageBuilder;
Expand Down Expand Up @@ -120,9 +121,9 @@ public void handleReturnValue(Object returnValue, MethodParameter returnType, Me
}

SimpMessageHeaderAccessor inputHeaders = SimpMessageHeaderAccessor.wrap(inputMessage);

String sessionId = inputHeaders.getSessionId();
MessagePostProcessor postProcessor = new SessionHeaderPostProcessor(sessionId);
SendToUserHeaderPostProcessor sendToUserHeaderPostProcessor = new SendToUserHeaderPostProcessor(sessionId, inputHeaders.getSubscriptionId());

SendToUser sendToUser = returnType.getMethodAnnotation(SendToUser.class);
if (sendToUser != null) {
Expand All @@ -136,9 +137,11 @@ public void handleReturnValue(Object returnValue, MethodParameter returnType, Me
}
String[] destinations = getTargetDestinations(sendToUser, inputHeaders, this.defaultUserDestinationPrefix);
for (String destination : destinations) {
this.messagingTemplate.convertAndSendToUser(userName, destination, returnValue, postProcessor);
}
return;
if(sendToUser.onSession())
this.messagingTemplate.convertAndSendToUser(userName, destination, returnValue, sendToUserHeaderPostProcessor);
else
this.messagingTemplate.convertAndSendToUser(userName, destination, returnValue, postProcessor);
}
}
else {
SendTo sendTo = returnType.getMethodAnnotation(SendTo.class);
Expand Down Expand Up @@ -178,6 +181,29 @@ public Message<?> postProcessMessage(Message<?> message) {
}
}

private final class SendToUserHeaderPostProcessor implements MessagePostProcessor {

private final String sessionId;

private final String subscriptionId;


public SendToUserHeaderPostProcessor(String sessionId, String subscriptionId) {
this.sessionId = sessionId;
this.subscriptionId = subscriptionId;
}

@Override
public Message<?> postProcessMessage(Message<?> message) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
headers.setSessionId(this.sessionId);
headers.setSubscriptionId(this.subscriptionId);
headers.setSessionIdFiltered(true);
headers.setMessageTypeIfNotSet(SimpMessageType.MESSAGE);
return MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build();
}
}

@Override
public String toString() {
return "SendToMethodReturnValueHandler [annotationRequired=" + annotationRequired + "]";
Expand Down
Expand Up @@ -129,7 +129,32 @@ else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) {
this.subscriptionRegistry.unregisterSubscription(message);
}
else if (SimpMessageType.MESSAGE.equals(messageType)) {
sendMessageToSubscribers(headers.getDestination(), message);
if (headers.isSessionFiltered() != null && headers.isSessionFiltered()) {
// If this is a message then we would not know the subscription Identifier.
if (headers.getSubscriptionId() == null) {
MultiValueMap<String,String> subscriptions = this.subscriptionRegistry.findSubscriptions(message);
if (headers.isSessionFiltered() != null && headers.isSessionFiltered()) {
if (subscriptions.get(headers.getSessionId()) != null ){
for (String subscriptionId : subscriptions.get(headers.getSessionId())) {
headers.setSubscriptionId(subscriptionId);
Object payload = message.getPayload();
Message<?> clientMessage = MessageBuilder.withPayload(payload).setHeaders(headers).build();
try {
this.clientOutboundChannel.send(clientMessage);
}
catch (Throwable ex) {
logger.error("Failed to send message to destination=" + destination +
", sessionId=" + headers.getSessionId() + ", subscriptionId=" + subscriptionId, ex);
}
}
}
}
} else {
this.clientOutboundChannel.send(message);
}
}
else
sendMessageToSubscribers(headers.getDestination(), message);
}
else if (SimpMessageType.DISCONNECT.equals(messageType)) {
String sessionId = headers.getSessionId();
Expand Down
Expand Up @@ -69,6 +69,7 @@ public class SendToMethodReturnValueHandlerTests {
private MethodParameter sendToDefaultDestReturnType;
private MethodParameter sendToUserReturnType;
private MethodParameter sendToUserDefaultDestReturnType;
private MethodParameter sendToUserOnSessionReturnType;


@Before
Expand Down Expand Up @@ -98,6 +99,10 @@ public void setup() throws Exception {
method = this.getClass().getDeclaredMethod("handleAndSendToUser");
this.sendToUserReturnType = new MethodParameter(method, -1);

method = this.getClass().getDeclaredMethod("handleAndSendToUserOnSession");
this.sendToUserOnSessionReturnType = new MethodParameter(method, -1);


method = this.getClass().getDeclaredMethod("handleAndSendToUserDefaultDestination");
this.sendToUserDefaultDestReturnType = new MethodParameter(method, -1);
}
Expand Down Expand Up @@ -125,6 +130,7 @@ public void sendToNoAnnotations() throws Exception {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
assertEquals("sess1", headers.getSessionId());
assertNull(headers.getSubscriptionId());
assertNull(headers.isSessionFiltered());
assertEquals("/topic/dest", headers.getDestination());
}

Expand All @@ -149,6 +155,7 @@ public void sendTo() throws Exception {
headers = SimpMessageHeaderAccessor.wrap(message);
assertEquals(sessionId, headers.getSessionId());
assertNull(headers.getSubscriptionId());
assertNull(headers.isSessionFiltered());
assertEquals("/dest2", headers.getDestination());
}

Expand All @@ -167,6 +174,7 @@ public void sendToDefaultDestination() throws Exception {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
assertEquals(sessionId, headers.getSessionId());
assertNull(headers.getSubscriptionId());
assertNull(headers.isSessionFiltered());
assertEquals("/topic/dest", headers.getDestination());
}

Expand All @@ -192,6 +200,7 @@ public void sendToUser() throws Exception {
headers = SimpMessageHeaderAccessor.wrap(message);
assertEquals(sessionId, headers.getSessionId());
assertNull(headers.getSubscriptionId());
assertNull(headers.isSessionFiltered());
assertEquals("/user/" + user.getName() + "/dest2", headers.getDestination());
}

Expand All @@ -208,9 +217,12 @@ public void sendToUserWithUserNameProvider() throws Exception {
verify(this.messageChannel, times(2)).send(this.messageCaptor.capture());

SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(this.messageCaptor.getAllValues().get(0));
assertNull(headers.getSubscriptionId());
assertEquals("/user/Me myself and I/dest1", headers.getDestination());

headers = SimpMessageHeaderAccessor.wrap(this.messageCaptor.getAllValues().get(1));
assertNull(headers.getSubscriptionId());
assertNull(headers.isSessionFiltered());
assertEquals("/user/Me myself and I/dest2", headers.getDestination());
}

Expand All @@ -230,9 +242,36 @@ public void sendToUserDefaultDestination() throws Exception {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
assertEquals(sessionId, headers.getSessionId());
assertNull(headers.getSubscriptionId());
assertNull(headers.isSessionFiltered());
assertEquals("/user/" + user.getName() + "/queue/dest", headers.getDestination());
}


@Test
public void sendToUserParticularSession() throws Exception {

when(this.messageChannel.send(any(Message.class))).thenReturn(true);

String sessionId = "sess1";
String subscriptionId = "subs1";
String destination = "/dest";
TestUser user = new TestUser();
Message<?> inputMessage = createInputMessage(sessionId, subscriptionId, destination, user);

this.handler.handleReturnValue(payloadContent, this.sendToUserOnSessionReturnType, inputMessage);

verify(this.messageChannel).send(this.messageCaptor.capture());
assertNotNull(this.messageCaptor.getValue());

verify(this.messageChannel, times(1)).send(this.messageCaptor.capture());

Message<?> message = this.messageCaptor.getAllValues().get(0);
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
assertEquals("sessionId should always be copied", sessionId, headers.getSessionId());
assertEquals(subscriptionId, headers.getSubscriptionId());
assertTrue(headers.isSessionFiltered());
assertEquals("/user/" + user.getName() + "/dest1", headers.getDestination());
}

private Message<?> createInputMessage(String sessId, String subsId, String destination, Principal principal) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create();
Expand Down Expand Up @@ -285,6 +324,11 @@ public String handleAndSendToUserDefaultDestination() {
return payloadContent;
}

@SendToUser(value={"/dest1"}, onSession=true)
public String handleAndSendToUserOnSession() {
return payloadContent;
}

@SendToUser({"/dest1", "/dest2"})
public String handleAndSendToUser() {
return payloadContent;
Expand Down
Expand Up @@ -89,6 +89,27 @@ public void subcribePublish() {
assertCapturedMessage("sess1", "sub3", "/bar");
assertCapturedMessage("sess2", "sub3", "/bar");
}

@Test
public void subcribePublishOnSession() {

this.messageHandler.start();

this.messageHandler.handleMessage(createSubscriptionMessage("sess1", "sub1", "/foo"));
this.messageHandler.handleMessage(createSubscriptionMessage("sess1", "sub2", "/foo"));
this.messageHandler.handleMessage(createSubscriptionMessage("sess1", "sub3", "/bar"));

this.messageHandler.handleMessage(createSubscriptionMessage("sess2", "sub1", "/foo"));
this.messageHandler.handleMessage(createSubscriptionMessage("sess2", "sub2", "/foo"));
this.messageHandler.handleMessage(createSubscriptionMessage("sess2", "sub3", "/bar"));

this.messageHandler.handleMessage(createMessageWithSubscriptionIdAndSessionId("/foo", "message1", "sess1", "sub1"));
this.messageHandler.handleMessage(createMessageWithSubscriptionIdAndSessionId("/bar", "message1", "sess2", "sub3"));

verify(this.clientOutboundChannel, times(2)).send(this.messageCaptor.capture());
assertCapturedMessage("sess1", "sub1", "/foo");
assertCapturedMessage("sess2", "sub3", "/bar");
}

@Test
public void subcribeDisconnectPublish() {
Expand Down Expand Up @@ -164,6 +185,16 @@ protected Message<String> createMessage(String destination, String payload) {
return MessageBuilder.withPayload(payload).copyHeaders(headers.toMap()).build();
}

protected Message<String> createMessageWithSubscriptionIdAndSessionId(String destination, String payload, String sessionId, String subscriptionId) {

SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
headers.setDestination(destination);
headers.setSessionId(sessionId);
headers.setSubscriptionId(subscriptionId);
headers.setSessionIdFiltered(Boolean.TRUE);
return MessageBuilder.withPayload(payload).copyHeaders(headers.toMap()).build();
}

protected boolean assertCapturedMessage(String sessionId, String subcriptionId, String destination) {
for (Message<?> message : this.messageCaptor.getAllValues()) {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
Expand Down

0 comments on commit 8f9e159

Please sign in to comment.