Skip to content

Commit

Permalink
Fixes #1146 - Review Oort/Seti channel usage
Browse files Browse the repository at this point in the history
Signed-off-by: Simone Bordet <simone.bordet@gmail.com>
  • Loading branch information
sbordet committed Feb 25, 2022
1 parent 4969dca commit bb445a1
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,6 @@ The wildcard mechanism works also for listeners, so it is possible to listen to
cometd.addListener("/meta/*", function(message) { ... });
----

By default, subscriptions to the global wildcards `+/*+` and `+/**+` result in an error, but you can change this behavior by specifying a custom security policy on the Bayeux server.

[[_javascript_meta_channels]]
==== Meta Channel List

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.EventListener;
import java.util.EventObject;
Expand All @@ -38,18 +39,19 @@
import org.cometd.bayeux.ChannelId;
import org.cometd.bayeux.Message;
import org.cometd.bayeux.client.ClientSession;
import org.cometd.bayeux.server.Authorizer;
import org.cometd.bayeux.server.BayeuxServer;
import org.cometd.bayeux.server.BayeuxServer.Extension;
import org.cometd.bayeux.server.LocalSession;
import org.cometd.bayeux.server.ServerChannel;
import org.cometd.bayeux.server.ServerMessage;
import org.cometd.bayeux.server.ServerMessage.Mutable;
import org.cometd.bayeux.server.ServerSession;
import org.cometd.client.ext.AckExtension;
import org.cometd.client.http.jetty.JettyHttpClientTransport;
import org.cometd.client.transport.ClientTransport;
import org.cometd.client.websocket.javax.WebSocketTransport;
import org.cometd.common.JSONContext;
import org.cometd.server.authorizer.GrantAuthorizer;
import org.cometd.server.ext.AcknowledgedMessagesExtension;
import org.cometd.server.ext.BinaryExtension;
import org.eclipse.jetty.client.HttpClient;
Expand Down Expand Up @@ -90,11 +92,14 @@ public class Oort extends ContainerLifeCycle {
public static final String OORT_CLOUD_CHANNEL = "/oort/cloud";
public static final String OORT_SERVICE_CHANNEL = "/service/oort";
static final String COMET_URL_ATTRIBUTE = EXT_OORT_FIELD + "." + EXT_COMET_URL_FIELD;
private static final List<String> PROTECTED_CHANNELS = Arrays.asList("/oort/**", "/oort/*", "/service/oort/**", "/service/oort/*", "/service/oort");

private final ConcurrentMap<String, Boolean> _channels = new ConcurrentHashMap<>();
private final CopyOnWriteArrayList<CometListener> _cometListeners = new CopyOnWriteArrayList<>();
private final ServerChannel.MessageListener _cloudListener = new CloudListener();
private final List<ClientTransport.Factory> _transportFactories = new ArrayList<>();
private final BayeuxServer.SubscriptionListener _allChannelsFilter = new AllChannelsFilter();
private final OortAuthorizer _authorizer = new OortAuthorizer();
private final BayeuxServer _bayeux;
private final String _url;
private final String _id;
Expand Down Expand Up @@ -163,28 +168,34 @@ protected void doStart() throws Exception {
}
}

_bayeux.addListener(_allChannelsFilter);

ServerChannel oortCloudChannel = _bayeux.createChannelIfAbsent(OORT_CLOUD_CHANNEL).getReference();
oortCloudChannel.addAuthorizer(GrantAuthorizer.GRANT_ALL);
oortCloudChannel.addListener(_cloudListener);

_oortSession.handshake();

protectOortChannels(_bayeux);

super.doStart();
}

@Override
protected void doStop() throws Exception {
super.doStop();

unprotectOortChannels(_bayeux);

_oortSession.disconnect();
_oortSession.removeExtension(_binaryExtension);

ServerChannel channel = _bayeux.getChannel(OORT_CLOUD_CHANNEL);
if (channel != null) {
channel.removeListener(_cloudListener);
channel.removeAuthorizer(GrantAuthorizer.GRANT_ALL);
}

_bayeux.removeListener(_allChannelsFilter);

Extension binaryExtension = _serverBinaryExtension;
_serverBinaryExtension = null;
if (binaryExtension != null) {
Expand All @@ -206,6 +217,19 @@ protected void doStop() throws Exception {
}
}

protected void protectOortChannels(BayeuxServer bayeux) {
PROTECTED_CHANNELS.forEach(name -> bayeux.createChannelIfAbsent(name, channel -> channel.addAuthorizer(_authorizer)));
}

protected void unprotectOortChannels(BayeuxServer bayeux) {
PROTECTED_CHANNELS.forEach(name -> {
ServerChannel channel = bayeux.getChannel(name);
if (channel != null) {
channel.removeAuthorizer(_authorizer);
}
});
}

protected ScheduledExecutorService getScheduler() {
return _scheduler;
}
Expand Down Expand Up @@ -703,4 +727,35 @@ public String getCometURL() {
}
}
}

private class AllChannelsFilter implements BayeuxServer.SubscriptionListener, ServerSession.MessageListener {
@Override
public void subscribed(ServerSession session, ServerChannel channel, ServerMessage message) {
if ("/**".equals(channel.getId()) && !session.isLocalSession()) {
session.addListener(this);
}
}

@Override
public boolean onMessage(ServerSession session, ServerSession sender, ServerMessage message) {
// Don't send Oort messages to the subscribers of the /** channel.
if (message.getChannel().startsWith("/oort/")) {
if (_logger.isDebugEnabled()) {
_logger.debug("Dropping Oort message {} to channel '/**' subscriber {}", message, session);
}
return false;
}
return true;
}
}

private class OortAuthorizer implements Authorizer {
@Override
public Result authorize(Operation operation, ChannelId channel, ServerSession session, ServerMessage message) {
if (session.isLocalSession() || isOort(session)) {
return Result.grant();
}
return Result.ignore();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import java.io.IOException;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.EventListener;
Expand All @@ -31,9 +32,11 @@
import java.util.TreeMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;
import org.cometd.bayeux.ChannelId;
import org.cometd.bayeux.Message;
import org.cometd.bayeux.Promise;
import org.cometd.bayeux.client.ClientSessionChannel;
import org.cometd.bayeux.server.Authorizer;
import org.cometd.bayeux.server.BayeuxServer;
import org.cometd.bayeux.server.LocalSession;
import org.cometd.bayeux.server.SecurityPolicy;
Expand Down Expand Up @@ -76,11 +79,14 @@
public class Seti extends AbstractLifeCycle implements Dumpable {
public static final String SETI_ATTRIBUTE = Seti.class.getName();
private static final String SETI_ALL_CHANNEL = "/seti/all";
private static final List<String> PROTECTED_CHANNELS = Arrays.asList("/seti/**", "/seti/*");

private final Map<String, Set<Location>> _uid2Location = new HashMap<>();
private final List<PresenceListener> _presenceListeners = new CopyOnWriteArrayList<>();
private final Oort.CometListener _cometListener = new CometListener();
private final ServerChannel.SubscriptionListener _initialStateListener = new InitialStateListener();
private final BayeuxServer.SubscriptionListener _allChannelsFilter = new AllChannelsFilter();
private final Authorizer _authorizer = new SetiAuthorizer();
private final Oort _oort;
private final String _setiId;
private final Logger _logger;
Expand All @@ -107,8 +113,12 @@ public String getId() {
protected void doStart() {
BayeuxServer bayeux = _oort.getBayeuxServer();

bayeux.addListener(_allChannelsFilter);

_session.handshake();

protectSetiChannels(bayeux);

ServerChannel setiAllChannel = bayeux.createChannelIfAbsent(SETI_ALL_CHANNEL).getReference();
setiAllChannel.addListener(_initialStateListener);
_session.getChannel(SETI_ALL_CHANNEL).subscribe((channel, message) -> receiveBroadcast(message));
Expand All @@ -127,21 +137,40 @@ protected void doStart() {

@Override
protected void doStop() {
BayeuxServer bayeux = _oort.getBayeuxServer();

removeAssociationsAndPresences();
_presenceListeners.clear();

_session.disconnect();

_oort.removeCometListener(_cometListener);

String setiChannelName = generateSetiChannel(_setiId);
_oort.deobserveChannel(setiChannelName);

_oort.deobserveChannel(SETI_ALL_CHANNEL);
ServerChannel setiAllChannel = _oort.getBayeuxServer().getChannel(SETI_ALL_CHANNEL);
ServerChannel setiAllChannel = bayeux.getChannel(SETI_ALL_CHANNEL);
if (setiAllChannel != null) {
setiAllChannel.removeListener(_initialStateListener);
}

unprotectSetiChannels(bayeux);

_session.disconnect();

bayeux.removeListener(_allChannelsFilter);
}

protected void protectSetiChannels(BayeuxServer bayeux) {
PROTECTED_CHANNELS.forEach(name -> bayeux.createChannelIfAbsent(name, channel -> channel.addAuthorizer(_authorizer)));
}

protected void unprotectSetiChannels(BayeuxServer bayeux) {
PROTECTED_CHANNELS.forEach(name -> {
ServerChannel channel = bayeux.getChannel(name);
if (channel != null) {
channel.removeAuthorizer(_authorizer);
}
});
}

protected String generateSetiId(String oortURL) {
Expand Down Expand Up @@ -1045,4 +1074,35 @@ public void subscribed(ServerSession session, ServerChannel channel, ServerMessa
}
}
}

private class AllChannelsFilter implements BayeuxServer.SubscriptionListener, ServerSession.MessageListener {
@Override
public void subscribed(ServerSession session, ServerChannel channel, ServerMessage message) {
if ("/**".equals(channel.getId()) && !session.isLocalSession()) {
session.addListener(this);
}
}

@Override
public boolean onMessage(ServerSession session, ServerSession sender, ServerMessage message) {
// Don't send Seti messages to the subscribers of the /** channel.
if (message.getChannel().startsWith("/seti/")) {
if (_logger.isDebugEnabled()) {
_logger.debug("Dropping Seti message {} to channel '/**' subscriber {}", message, session);
}
return false;
}
return true;
}
}

private class SetiAuthorizer implements Authorizer {
@Override
public Result authorize(Operation operation, ChannelId channel, ServerSession session, ServerMessage message) {
if (session.isLocalSession() || getOort().isOort(session)) {
return Result.grant();
}
return Result.ignore();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1322,6 +1322,69 @@ public void cometLeft(Event event) {
Assertions.assertEquals(1, joinCount.get());
}

@ParameterizedTest
@MethodSource("transports")
public void testProtectedOortChannels(String serverTransport) throws Exception {
Server server1 = startServer(serverTransport, 0);
Oort oort1 = startOort(server1);

BayeuxClient client = startClient(oort1, null);
CountDownLatch subscribeLatch = new CountDownLatch(1);
client.getChannel("/oort/*").subscribe((channel, message) -> {}, message -> {
// Must not be able to subscribe.
if (!message.isSuccessful()) {
subscribeLatch.countDown();
}
});
Assertions.assertTrue(subscribeLatch.await(1, TimeUnit.SECONDS));

CountDownLatch publishLatch1 = new CountDownLatch(1);
client.getChannel("/oort/cloud").publish("data1", message -> {
// Must not be able to publish.
if (!message.isSuccessful()) {
publishLatch1.countDown();
}
});
Assertions.assertTrue(publishLatch1.await(1, TimeUnit.SECONDS));

CountDownLatch publishLatch2 = new CountDownLatch(1);
client.getChannel("/service/oort").publish("data2", message -> {
// Must not be able to publish.
if (!message.isSuccessful()) {
publishLatch2.countDown();
}
});
Assertions.assertTrue(publishLatch2.await(1, TimeUnit.SECONDS));

String broadcastChannel = "/broadcast";
CountDownLatch allMessageLatch = new CountDownLatch(1);
CountDownLatch allSubscribeLatch = new CountDownLatch(1);
client.getChannel("/**").subscribe((channel, message) -> {
String channelName = message.getChannel();
if (channelName.startsWith("/oort") || channelName.equals(broadcastChannel)) {
allMessageLatch.countDown();
}
}, message -> {
if (message.isSuccessful()) {
allSubscribeLatch.countDown();
}
});

Assertions.assertTrue(allSubscribeLatch.await(5, TimeUnit.SECONDS));

// Cause an Oort message to be broadcast.
Server server2 = startServer(serverTransport, 0);
Oort oort2 = startOort(server2);
oort1.observeComet(oort2.getURL());

// Make sure it was not received.
Assertions.assertFalse(allMessageLatch.await(1, TimeUnit.SECONDS));

// Publish a non-Oort message, make sure it's received.
client.getChannel(broadcastChannel).publish("hello");
Assertions.assertTrue(allMessageLatch.await(5, TimeUnit.SECONDS));
}

private void sleep(long time) {
try {
Thread.sleep(time);
Expand Down

0 comments on commit bb445a1

Please sign in to comment.